diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py index 8a2aa66e702593c832ccf6c45dd3c2804cc14db9..c734413b7418a4d88e5e51188348a27526e861ba 100644 --- a/src/finn/core/modelwrapper.py +++ b/src/finn/core/modelwrapper.py @@ -54,28 +54,19 @@ class ModelWrapper: """Run given anaylsis_fxn on this model and return resulting dict.""" return analysis_fxn(self) - def transform_repeated(self, transform, make_deepcopy=True): - """Applies given transform repeatedly until no more changes can be made + def transform(self, transformation, make_deepcopy=True): + """Applies given Transformation repeatedly until no more changes can be made and returns a transformed ModelWrapper instance. If make_deepcopy is specified, operates on a new (deep)copy of model. - Transform must return (transformed_model, model_was_changed).""" + """ transformed_model = self if make_deepcopy: transformed_model = copy.deepcopy(self) model_was_changed = True while model_was_changed: - (transformed_model, model_was_changed) = transform(transformed_model) - return transformed_model - - def transform_single(self, transform, make_deepcopy=True): - """Applies given transform once and returns transformed ModelWrapper - instance. If make_deepcopy is specified, operates on a new (deep)copy of - model. Transform must return (transformed_model, model_was_changed), - although model_was_changed is ignored (see also apply_repeated).""" - transformed_model = self - if make_deepcopy: - transformed_model = copy.deepcopy(self) - (transformed_model, model_was_changed) = transform(transformed_model) + (transformed_model, model_was_changed) = transformation.apply( + transformed_model + ) return transformed_model def check_compatibility(self): diff --git a/src/finn/transformation/__init__.py b/src/finn/transformation/__init__.py index 727b8d32a60793c9f32df2df5120b069fb6528dc..3ddce04c11db01b2f6722bc843f0107621630936 100644 --- a/src/finn/transformation/__init__.py +++ b/src/finn/transformation/__init__.py @@ -2,9 +2,10 @@ Guide to writing FINN transformations ------------------------------------- -* Your transformation should take in a ModelWrapper, and return a tuple with - (transformed_model: ModelWrapper, model_was_changed: Bool) -* The transformations are meant to be applied using the .transform functions +* Your transformation must inherit the Transformation abstract base class. +* Your transformation's apply function should take in a ModelWrapper, and return + a tuple with (transformed_model: ModelWrapper, model_was_changed: Bool) +* The transformations are meant to be applied using the .transform function in ModelWrapper. This makes a deep copy of the input model by default, so you don't have to. * model_was_changed indicates whether your transformation made any changes to @@ -14,6 +15,17 @@ Guide to writing FINN transformations * You MUST return model_was_changed=False at some point when your transformation is called multiple times, otherwise apply_repeated() will loop infinitely. * If you cannot guarantee that the transformation will reach a fixed point, - you must declare this and notify the user to use .transform_single() instead - of .transform_repeated() + you must declare this, return model_was_changed = False and let the user + manually re-apply the transform. """ + +from abc import ABC, abstractmethod + + +class Transformation(ABC): + def __init__(self): + super().__init__() + + @abstractmethod + def apply(self, model): + pass diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py index 446e27ba50637fc0ef961f23b26f85bc8235b4fc..655ddd9842f37d59155aa0b12edeffecd89d65c1 100644 --- a/src/finn/transformation/batchnorm_to_affine.py +++ b/src/finn/transformation/batchnorm_to_affine.py @@ -2,70 +2,73 @@ import numpy as np from onnx import TensorProto from onnx import helper as oh -import finn.transformation.infer_shapes as si +from finn.transformation import Transformation +from finn.transformation.infer_shapes import InferShapes -def batchnorm_to_affine(model): +class BatchNormToAffine(Transformation): """Replaces any test-time BatchNorm layers with Mul-Add layers.""" - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "BatchNormalization": - graph_modified = True - bn_input = n.input[0] - bn_output = n.output[0] - # extract batchnorm parameters as numpy arrays - scale = model.get_initializer(n.input[1]) - bias = model.get_initializer(n.input[2]) - mean = model.get_initializer(n.input[3]) - variance = model.get_initializer(n.input[4]) - epsilon = 1e-5 - # find A and B to compute batchnorm as affine transpose Ax+B - # TODO is a division by moving avg factor needed for variance? - A = scale / np.sqrt(epsilon + variance) - B = bias - (A * mean) - # see if we have surrounding Unsqueeze/Squeeze nodes we can remove - producer = model.find_producer(bn_input) - if producer is not None: - if producer.op_type == "Unsqueeze": - bn_input = producer.input[0] - consumer = model.find_consumer(bn_output) - if consumer is not None: - if consumer.op_type == "Squeeze": - bn_output = consumer.output[0] - data_shape = model.get_tensor_shape(bn_input) - # create value_info and initializers for Mul and Add constants - mul_const = oh.make_tensor_value_info( - model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape - ) - graph.value_info.append(mul_const) - model.set_initializer(mul_const.name, A) - mul_output = oh.make_tensor_value_info( - model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape - ) - graph.value_info.append(mul_output) - add_const = oh.make_tensor_value_info( - model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape - ) - graph.value_info.append(add_const) - model.set_initializer(add_const.name, B) - # create Mul and Add nodes to replace the batchnorm - mul_node = oh.make_node( - "Mul", [bn_input, mul_const.name], [mul_output.name] - ) - add_node = oh.make_node( - "Add", [mul_output.name, add_const.name], [bn_output] - ) - # insert where the batchnorm is to preserve topological ordering - graph.node.insert(node_ind, mul_node) - graph.node.insert(node_ind + 1, add_node) - # remove old nodes - graph.node.remove(n) - if consumer is not None: - graph.node.remove(consumer) - if producer is not None: - graph.node.remove(producer) - model = model.transform_single(si.infer_shapes) - return (model, graph_modified) + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "BatchNormalization": + graph_modified = True + bn_input = n.input[0] + bn_output = n.output[0] + # extract batchnorm parameters as numpy arrays + scale = model.get_initializer(n.input[1]) + bias = model.get_initializer(n.input[2]) + mean = model.get_initializer(n.input[3]) + variance = model.get_initializer(n.input[4]) + epsilon = 1e-5 + # find A and B to compute batchnorm as affine transpose Ax+B + # TODO is a division by moving avg factor needed for variance? + A = scale / np.sqrt(epsilon + variance) + B = bias - (A * mean) + # see if we have surrounding Unsqueeze/Squeeze nodes we can remove + producer = model.find_producer(bn_input) + if producer is not None: + if producer.op_type == "Unsqueeze": + bn_input = producer.input[0] + consumer = model.find_consumer(bn_output) + if consumer is not None: + if consumer.op_type == "Squeeze": + bn_output = consumer.output[0] + data_shape = model.get_tensor_shape(bn_input) + # create value_info and initializers for Mul and Add constants + mul_const = oh.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape + ) + graph.value_info.append(mul_const) + model.set_initializer(mul_const.name, A) + mul_output = oh.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape + ) + graph.value_info.append(mul_output) + add_const = oh.make_tensor_value_info( + model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape + ) + graph.value_info.append(add_const) + model.set_initializer(add_const.name, B) + # create Mul and Add nodes to replace the batchnorm + mul_node = oh.make_node( + "Mul", [bn_input, mul_const.name], [mul_output.name] + ) + add_node = oh.make_node( + "Add", [mul_output.name, add_const.name], [bn_output] + ) + # insert where the batchnorm is to preserve topological ordering + graph.node.insert(node_ind, mul_node) + graph.node.insert(node_ind + 1, add_node) + # remove old nodes + graph.node.remove(n) + if consumer is not None: + graph.node.remove(consumer) + if producer is not None: + graph.node.remove(producer) + model = model.transform(InferShapes()) + return (model, graph_modified) diff --git a/src/finn/transformation/fold_constants.py b/src/finn/transformation/fold_constants.py index a9331d3069ab81e919bd6a29f41243dc1e07ab51..5b27d906cc8ee4cbcaf7363001eb7297b5c21000 100644 --- a/src/finn/transformation/fold_constants.py +++ b/src/finn/transformation/fold_constants.py @@ -1,31 +1,34 @@ import finn.core.onnx_exec as oxe -import finn.transformation.infer_shapes as si +from finn.transformation import Transformation +from finn.transformation.infer_shapes import InferShapes -def fold_constants(model): +class FoldConstants(Transformation): """Replace the output of a node with const-only inputs with a precomputed result.""" - graph = model.graph - node_ind = 0 - graph_modified = False - execution_context = model.make_empty_exec_context() - for n in graph.node: - node_ind += 1 - node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input)) - node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits)) - node_out = n.output[0] - is_all_constant_inputs = len(node_inp_dyn) == 0 - ishape = model.get_tensor_shape(n.input[0]) - is_const_shape = (n.op_type == "Shape") and (ishape is not None) - if is_all_constant_inputs or is_const_shape: - # this node has no dynamic inputs, only constant ones -- so we can - # do constant folding. - oxe.execute_node(n, execution_context, graph) - # use the execution result as an initializer - model.set_initializer(node_out, execution_context[node_out]) - # remove old node - graph.node.remove(n) - graph_modified = True - if graph_modified: - model = model.transform_single(si.infer_shapes) - return (model, graph_modified) + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + execution_context = model.make_empty_exec_context() + for n in graph.node: + node_ind += 1 + node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input)) + node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits)) + node_out = n.output[0] + is_all_constant_inputs = len(node_inp_dyn) == 0 + ishape = model.get_tensor_shape(n.input[0]) + is_const_shape = (n.op_type == "Shape") and (ishape is not None) + if is_all_constant_inputs or is_const_shape: + # this node has no dynamic inputs, only constant ones -- so we can + # do constant folding. + oxe.execute_node(n, execution_context, graph) + # use the execution result as an initializer + model.set_initializer(node_out, execution_context[node_out]) + # remove old node + graph.node.remove(n) + graph_modified = True + if graph_modified: + model = model.transform(InferShapes()) + return (model, graph_modified) diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py index 44a4614c660c240f138ad1574cd05e292d1fcd0b..b6845312b06c409fe244ad1126d6ce37c0d85ca2 100644 --- a/src/finn/transformation/general.py +++ b/src/finn/transformation/general.py @@ -1,59 +1,68 @@ import finn.core.utils as util +from finn.transformation import Transformation -def give_unique_node_names(model): +class GiveUniqueNodeNames(Transformation): """Give unique names to each node in the graph using enumeration.""" - optype_count = {} - for n in model.graph.node: - if n.op_type not in optype_count.keys(): - optype_count[n.op_type] = 0 - n.name = "%s_%d" % (n.op_type, optype_count[n.op_type]) - optype_count[n.op_type] += 1 - # return model_was_changed = False as single iteration is always enough - return (model, False) + def apply(self, model): + optype_count = {} + for n in model.graph.node: + if n.op_type not in optype_count.keys(): + optype_count[n.op_type] = 0 + n.name = "%s_%d" % (n.op_type, optype_count[n.op_type]) + optype_count[n.op_type] += 1 + # return model_was_changed = False as single iteration is always enough + return (model, False) -def give_random_tensor_names(model): + +class GiveRandomTensorNames(Transformation): """Give random tensor names to all tensors.""" - names = model.get_all_tensor_names() - for name in names: - model.rename_tensor(name, util.random_string()) - # return model_was_changed = False as single iteration is always enough - return (model, False) + + def apply(self, model): + names = model.get_all_tensor_names() + for name in names: + model.rename_tensor(name, util.random_string()) + # return model_was_changed = False as single iteration is always enough + return (model, False) -def give_readable_tensor_names(model): +class GiveReadableTensorNames(Transformation): """Give more human-readable names to all internal tensors. It's recommended to apply give_unique_node_names prior to this transform.""" - # to ensure we can use rename_tensor safely (without renaming existing - # tensors) we start by giving random names to all tensors - model = model.transform_single(give_random_tensor_names) - graph = model.graph - for n in graph.node: - out_num = 0 - for o in n.output: - model.rename_tensor(o, "%s_out%d" % (n.name, out_num)) - out_num += 1 - init_in_num = 0 - for i in n.input: - if model.get_initializer(i) is not None: - model.rename_tensor(i, "%s_param%d" % (n.name, init_in_num)) - init_in_num += 1 - # give special names to the main model input and output - model.rename_tensor(model.graph.input[0].name, "global_in") - model.rename_tensor(model.graph.output[0].name, "global_out") - # return model_was_changed = False as single iteration is always enough - return (model, False) - - -def convert_sub_to_add(model): + + def apply(self, model): + # to ensure we can use rename_tensor safely (without renaming existing + # tensors) we start by giving random names to all tensors + model = model.transform(GiveRandomTensorNames()) + graph = model.graph + for n in graph.node: + out_num = 0 + for o in n.output: + model.rename_tensor(o, "%s_out%d" % (n.name, out_num)) + out_num += 1 + init_in_num = 0 + for i in n.input: + if model.get_initializer(i) is not None: + model.rename_tensor(i, "%s_param%d" % (n.name, init_in_num)) + init_in_num += 1 + # give special names to the main model input and output + model.rename_tensor(model.graph.input[0].name, "global_in") + model.rename_tensor(model.graph.output[0].name, "global_out") + # return model_was_changed = False as single iteration is always enough + return (model, False) + + +class ConvertSubToAdd(Transformation): """Convert sub nodes to add nodes of appropriate sign.""" - graph = model.graph - for n in graph.node: - if n.op_type == "Sub": - A = model.get_initializer(n.input[1]) - if A is not None: - n.op_type = "Add" - model.set_initializer(n.input[1], -A) - # return model_was_changed = False as single iteration is always enough - return (model, False) + + def apply(self, model): + graph = model.graph + for n in graph.node: + if n.op_type == "Sub": + A = model.get_initializer(n.input[1]) + if A is not None: + n.op_type = "Add" + model.set_initializer(n.input[1], -A) + # return model_was_changed = False as single iteration is always enough + return (model, False) diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py index 19d947b57d045d4d7f2523f0f392adaba5bb367a..a311012fc6631e76e75a37b8dc4d1b99d21ce7c7 100644 --- a/src/finn/transformation/infer_datatypes.py +++ b/src/finn/transformation/infer_datatypes.py @@ -1,4 +1,5 @@ from finn.core.datatype import DataType +from finn.transformation import Transformation def _infer_node_datatype(model, node): @@ -37,11 +38,13 @@ def _infer_node_datatype(model, node): return graph_modified -def infer_datatypes(model): +class InferDataTypes(Transformation): """Infer FINN DataType info for all intermediate/output tensors based on - inputs and node type.""" - graph = model.graph - graph_modified = False - for node in graph.node: - graph_modified |= _infer_node_datatype(model, node) - return (model, graph_modified) + inputs and node type.""" + + def apply(self, model): + graph = model.graph + graph_modified = False + for node in graph.node: + graph_modified |= _infer_node_datatype(model, node) + return (model, graph_modified) diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py index 063fb9b8b686a880283678e1bf779c09da209a5a..e92c6f81625a9d328bed3225a7730a943d6c9830 100644 --- a/src/finn/transformation/infer_shapes.py +++ b/src/finn/transformation/infer_shapes.py @@ -2,6 +2,8 @@ import onnx.helper as helper import onnx.shape_inference as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation import Transformation + def _make_shape_compatible_op(node): @@ -9,7 +11,7 @@ def _make_shape_compatible_op(node): shape inference with custom ops.""" assert node.domain == "finn" if node.op_type == "MultiThreshold": - return helper.make_node("ReLU", [node.input[0]], [node.output[0]]) + return helper.make_node("Relu", [node.input[0]], [node.output[0]]) else: raise Exception("No known shape-compatible op for %s" % node.op_type) @@ -44,12 +46,14 @@ def _restore_finn_ops(model, hidden_ops): pass -def infer_shapes(model): +class InferShapes(Transformation): """Ensure every tensor in the model has a specified shape (ValueInfo).""" - # hide your riches! - hidden_ops = _hide_finn_ops(model) - # call regular ONNX shape inference - model = ModelWrapper(si.infer_shapes(model.model)) - # bring back hidden ops - _restore_finn_ops(model, hidden_ops) - return (model, False) + + def apply(self, model): + # hide your riches! + hidden_ops = _hide_finn_ops(model) + # call regular ONNX shape inference + model = ModelWrapper(si.infer_shapes(model.model)) + # bring back hidden ops + _restore_finn_ops(model, hidden_ops) + return (model, False) diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index 5053b23c3e09bc5f32807cdf16a997c4647dd165..fb9e530d641063187dd75de30c8ca49936565bae 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -1,384 +1,416 @@ import numpy as np from onnx import helper as oh -import finn.transformation.infer_shapes as si from finn.core.datatype import DataType +from finn.transformation import Transformation +from finn.transformation.infer_shapes import InferShapes -def convert_sign_to_thres(model): +class ConvertSignToThres(Transformation): """Convert Sign node instances to MultiThreshold with threshold at 0.""" - graph = model.graph - graph_modified = False - node_ind = 0 - for n in graph.node: - node_ind += 1 - if n.op_type == "Sign": - sign_out_name = n.output[0] - # find consumer - consumer = model.find_consumer(sign_out_name) - assert consumer is not None - # change op type and create threshold - n.op_type = "MultiThreshold" - thres_param_name = model.make_new_valueinfo_name() - thres_param = np.asarray([[0]], dtype=np.float32) - n.input.append(thres_param_name) - n.domain = "finn" - model.set_initializer(thres_param_name, thres_param) - # convert 0,1 -> -1,+1 with 2*x-1 - out_shape = model.get_tensor_shape(sign_out_name) - # make a mul node - # note how set_initializer or set_tensor_shape is called before - # calling make_new_valueinfo_name again - mul_param_name = model.make_new_valueinfo_name() - model.set_initializer(mul_param_name, np.asarray([[2]], dtype=np.float32)) - mul_out_name = model.make_new_valueinfo_name() - model.set_tensor_shape(mul_out_name, out_shape) - mul_node = oh.make_node( - "Mul", [sign_out_name, mul_param_name], [mul_out_name] - ) - # make an add node - add_param_name = model.make_new_valueinfo_name() - model.set_initializer(add_param_name, np.asarray([[-1]], dtype=np.float32)) - add_out_name = model.make_new_valueinfo_name() - model.set_tensor_shape(add_out_name, out_shape) - add_node = oh.make_node( - "Add", [mul_out_name, add_param_name], [add_out_name] - ) - # add new nodes to graph at correct position - graph.node.insert(node_ind, mul_node) - graph.node.insert(node_ind + 1, add_node) - # rewrite consumer's input - consumer.input[0] = add_out_name - # add quantization annotations - model.set_tensor_datatype(sign_out_name, DataType.BINARY) - model.set_tensor_datatype(mul_out_name, DataType.UINT2) - model.set_tensor_datatype(add_out_name, DataType.BIPOLAR) - graph_modified = True - return (model, graph_modified) - - -def collapse_repeated_op(model, op_name, make_collapsed_param_fxn): - """Collapse repeated consecutive operations with constant parameters into - a single operation. make_collapsed_param_fxn must take two tensors and - return a tensor which gives the equivalent result using a single op. """ - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == op_name: - consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == op_name: - op0_param_name = n.input[1] - op1_param_name = consumer.input[1] - op0_param = model.get_initializer(op0_param_name) - op1_param = model.get_initializer(op1_param_name) - assert op0_param is not None - assert op1_param is not None - start_name = n.input[0] - end_name = consumer.output[0] - # compute the new parameter - new_param = make_collapsed_param_fxn(op0_param, op1_param) - # make and insert new node - new_node_param_name = op0_param_name - new_node = oh.make_node( - op_name, [start_name, new_node_param_name], [end_name] - ) - graph.node.insert(node_ind, new_node) - # replace parameter value - model.set_initializer(new_node_param_name, new_param) - # remove old nodes - graph.node.remove(n) - graph.node.remove(consumer) - graph_modified = True - model = model.transform_single(si.infer_shapes) - return (model, graph_modified) - -def collapse_repeated_add(model): - return collapse_repeated_op(model, "Add", lambda x, y: y + x) - - -def collapse_repeated_mul(model): - return collapse_repeated_op(model, "Mul", lambda x, y: y * x) - - -def move_add_past_mul(model): - """Move add operations past multiply operations. The aim is to have them - next to each other such that they can be collapsed into a single add.""" - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Add": - consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == "Mul": - # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA) - # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA) - # assume input 0 is from the previous layer, input 1 is the - # trained (constant) parameter - mul_weight_name = consumer.input[1] - add_weight_name = n.input[1] - A = model.get_initializer(mul_weight_name) - B = model.get_initializer(add_weight_name) - assert A is not None - assert B is not None - start_name = n.input[0] - middle_name = n.output[0] - end_name = consumer.output[0] - # compute new param value for add - BA = B * A - # make and insert new nodes - new_mul = oh.make_node( - "Mul", [start_name, mul_weight_name], [middle_name] + def apply(self, model): + graph = model.graph + graph_modified = False + node_ind = 0 + for n in graph.node: + node_ind += 1 + if n.op_type == "Sign": + sign_out_name = n.output[0] + # find consumer + consumer = model.find_consumer(sign_out_name) + assert consumer is not None + # change op type and create threshold + n.op_type = "MultiThreshold" + thres_param_name = model.make_new_valueinfo_name() + thres_param = np.asarray([[0]], dtype=np.float32) + n.input.append(thres_param_name) + n.domain = "finn" + model.set_initializer(thres_param_name, thres_param) + # convert 0,1 -> -1,+1 with 2*x-1 + out_shape = model.get_tensor_shape(sign_out_name) + # make a mul node + # note how set_initializer or set_tensor_shape is called before + # calling make_new_valueinfo_name again + mul_param_name = model.make_new_valueinfo_name() + model.set_initializer( + mul_param_name, np.asarray([[2]], dtype=np.float32) + ) + mul_out_name = model.make_new_valueinfo_name() + model.set_tensor_shape(mul_out_name, out_shape) + mul_node = oh.make_node( + "Mul", [sign_out_name, mul_param_name], [mul_out_name] ) - new_add = oh.make_node( - "Add", [middle_name, add_weight_name], [end_name] + # make an add node + add_param_name = model.make_new_valueinfo_name() + model.set_initializer( + add_param_name, np.asarray([[-1]], dtype=np.float32) ) - graph.node.insert(node_ind, new_mul) - graph.node.insert(node_ind + 1, new_add) - # replace add value - model.set_initializer(add_weight_name, BA) - # remove old nodes - graph.node.remove(n) - graph.node.remove(consumer) + add_out_name = model.make_new_valueinfo_name() + model.set_tensor_shape(add_out_name, out_shape) + add_node = oh.make_node( + "Add", [mul_out_name, add_param_name], [add_out_name] + ) + # add new nodes to graph at correct position + graph.node.insert(node_ind, mul_node) + graph.node.insert(node_ind + 1, add_node) + # rewrite consumer's input + consumer.input[0] = add_out_name + # add quantization annotations + model.set_tensor_datatype(sign_out_name, DataType.BINARY) + model.set_tensor_datatype(mul_out_name, DataType.UINT2) + model.set_tensor_datatype(add_out_name, DataType.BIPOLAR) graph_modified = True - model = model.transform_single(si.infer_shapes) - return (model, graph_modified) + return (model, graph_modified) -def move_scalar_mul_past_matmul(model): - """Move scalar mul operations past matmul operations. We want to have muls - next to each other such that they can be collapsed into a single mul.""" - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Mul": - consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == "MatMul": - mul_weight_name = n.input[1] - matmul_weight_name = consumer.input[1] - A = model.get_initializer(mul_weight_name) - W = model.get_initializer(matmul_weight_name) - assert A is not None - assert W is not None - start_name = n.input[0] - middle_name = n.output[0] - end_name = consumer.output[0] - mm_out_shape = model.get_tensor_shape(end_name) - if all(x == 1 for x in A.shape): - # if the mul is scalar, we can simply swap the order of ops - # make and insert new nodes - new_matmul = oh.make_node( - "MatMul", [start_name, matmul_weight_name], [middle_name] - ) - new_mul = oh.make_node( - "Mul", [middle_name, mul_weight_name], [end_name] +class CollapseRepeatedOp(Transformation): + """Collapse repeated consecutive operations with constant parameters into + a single operation. make_collapsed_param_fxn must take two tensors and + return a tensor which gives the equivalent result using a single op. """ + + def __init__(self, op_name, make_collapsed_param_fxn): + super().__init__() + self.op_name = op_name + self.make_collapsed_param_fxn = make_collapsed_param_fxn + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == self.op_name: + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == self.op_name: + op0_param_name = n.input[1] + op1_param_name = consumer.input[1] + op0_param = model.get_initializer(op0_param_name) + op1_param = model.get_initializer(op1_param_name) + assert op0_param is not None + assert op1_param is not None + start_name = n.input[0] + end_name = consumer.output[0] + # compute the new parameter + new_param = self.make_collapsed_param_fxn(op0_param, op1_param) + # make and insert new node + new_node_param_name = op0_param_name + new_node = oh.make_node( + self.op_name, [start_name, new_node_param_name], [end_name] ) - graph.node.insert(node_ind, new_matmul) - graph.node.insert(node_ind + 1, new_mul) - model.set_tensor_shape(middle_name, mm_out_shape) + graph.node.insert(node_ind, new_node) + # replace parameter value + model.set_initializer(new_node_param_name, new_param) # remove old nodes graph.node.remove(n) graph.node.remove(consumer) graph_modified = True - model = model.transform_single(si.infer_shapes) - return (model, graph_modified) + model = model.transform(InferShapes()) + return (model, graph_modified) -def move_scalar_add_past_matmul(model): - """Move scalar add operations past matmul operations. We want to have adds +class CollapseRepeatedAdd(CollapseRepeatedOp): + def __init__(self): + super().__init__("Add", lambda x, y: y + x) + + +class CollapseRepeatedMul(CollapseRepeatedOp): + def __init__(self): + super().__init__("Mul", lambda x, y: y * x) + + +class MoveAddPastMul(Transformation): + """Move add operations past multiply operations. The aim is to have them next to each other such that they can be collapsed into a single add.""" - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Add": - consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == "MatMul": - add_weight_name = n.input[1] - matmul_weight_name = consumer.input[1] - A = model.get_initializer(add_weight_name) - W = model.get_initializer(matmul_weight_name) - assert A is not None - assert W is not None - start_name = n.input[0] - middle_name = n.output[0] - end_name = consumer.output[0] - mm_out_shape = model.get_tensor_shape(end_name) - if all(x == 1 for x in A.shape): - # if the add is scalar, we can move it past the matmul - # by taking it past the matmul with a dot product - Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W) - # update the add weight - model.set_initializer(add_weight_name, Anew) - new_matmul = oh.make_node( - "MatMul", [start_name, matmul_weight_name], [middle_name] + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Add": + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "Mul": + # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA) + # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA) + # assume input 0 is from the previous layer, input 1 is the + # trained (constant) parameter + mul_weight_name = consumer.input[1] + add_weight_name = n.input[1] + A = model.get_initializer(mul_weight_name) + B = model.get_initializer(add_weight_name) + assert A is not None + assert B is not None + start_name = n.input[0] + middle_name = n.output[0] + end_name = consumer.output[0] + # compute new param value for add + BA = B * A + # make and insert new nodes + new_mul = oh.make_node( + "Mul", [start_name, mul_weight_name], [middle_name] ) new_add = oh.make_node( "Add", [middle_name, add_weight_name], [end_name] ) - graph.node.insert(node_ind, new_matmul) + graph.node.insert(node_ind, new_mul) graph.node.insert(node_ind + 1, new_add) - model.set_tensor_shape(middle_name, mm_out_shape) + # replace add value + model.set_initializer(add_weight_name, BA) # remove old nodes graph.node.remove(n) graph.node.remove(consumer) graph_modified = True - model = model.transform_single(si.infer_shapes) - return (model, graph_modified) + model = model.transform(InferShapes()) + return (model, graph_modified) -def absorb_add_into_multi_threshold(model): +class MoveScalarMulPastMatMul(Transformation): + """Move scalar mul operations past matmul operations. We want to have muls + next to each other such that they can be collapsed into a single mul.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Mul": + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "MatMul": + mul_weight_name = n.input[1] + matmul_weight_name = consumer.input[1] + A = model.get_initializer(mul_weight_name) + W = model.get_initializer(matmul_weight_name) + assert A is not None + assert W is not None + start_name = n.input[0] + middle_name = n.output[0] + end_name = consumer.output[0] + mm_out_shape = model.get_tensor_shape(end_name) + if all(x == 1 for x in A.shape): + # if the mul is scalar, we can simply swap the order of ops + # make and insert new nodes + new_matmul = oh.make_node( + "MatMul", [start_name, matmul_weight_name], [middle_name] + ) + new_mul = oh.make_node( + "Mul", [middle_name, mul_weight_name], [end_name] + ) + graph.node.insert(node_ind, new_matmul) + graph.node.insert(node_ind + 1, new_mul) + model.set_tensor_shape(middle_name, mm_out_shape) + # remove old nodes + graph.node.remove(n) + graph.node.remove(consumer) + graph_modified = True + model = model.transform(InferShapes()) + return (model, graph_modified) + + +class MoveScalarAddPastMatMul(Transformation): + """Move scalar add operations past matmul operations. We want to have adds + next to each other such that they can be collapsed into a single add.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Add": + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "MatMul": + add_weight_name = n.input[1] + matmul_weight_name = consumer.input[1] + A = model.get_initializer(add_weight_name) + W = model.get_initializer(matmul_weight_name) + assert A is not None + assert W is not None + start_name = n.input[0] + middle_name = n.output[0] + end_name = consumer.output[0] + mm_out_shape = model.get_tensor_shape(end_name) + if all(x == 1 for x in A.shape): + # if the add is scalar, we can move it past the matmul + # by taking it past the matmul with a dot product + Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W) + # update the add weight + model.set_initializer(add_weight_name, Anew) + new_matmul = oh.make_node( + "MatMul", [start_name, matmul_weight_name], [middle_name] + ) + new_add = oh.make_node( + "Add", [middle_name, add_weight_name], [end_name] + ) + graph.node.insert(node_ind, new_matmul) + graph.node.insert(node_ind + 1, new_add) + model.set_tensor_shape(middle_name, mm_out_shape) + # remove old nodes + graph.node.remove(n) + graph.node.remove(consumer) + graph_modified = True + model = model.transform(InferShapes()) + return (model, graph_modified) + + +class AbsorbAddIntoMultiThreshold(Transformation): """Absorb preceding Add ops into MultiThreshold by updating the threshold values.""" - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Add": - consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == "MultiThreshold": - add_weight_name = n.input[1] - threshold_name = consumer.input[1] - A = model.get_initializer(add_weight_name) - T = model.get_initializer(threshold_name) - assert A is not None - assert T is not None - start_name = n.input[0] - # compute new thresholds and set initializer - Tnew = T - A.reshape(-1, T.shape[1]) - model.set_initializer(threshold_name, Tnew) - # wire add input directly to MultiThreshold - consumer.input[0] = start_name - # remove the add node - graph.node.remove(n) - graph_modified = True - return (model, graph_modified) - -def absorb_mul_into_multi_threshold(model): - """Absorb preceding Mul ops into MultiThreshold by updating the threshold - values. Only *positive* scalar/1D vectors can be absorbed.""" - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Mul": - mul_weight_name = n.input[1] - A = model.get_initializer(mul_weight_name) - assert A is not None - is_signed = (A < 0).any() - is_scalar = np.prod(A.shape) == 1 - is_1d = len(A.shape) == 2 and A.shape[0] == 1 - consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == "MultiThreshold": - if not is_signed and (is_1d or is_scalar): + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Add": + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "MultiThreshold": + add_weight_name = n.input[1] threshold_name = consumer.input[1] + A = model.get_initializer(add_weight_name) T = model.get_initializer(threshold_name) + assert A is not None assert T is not None start_name = n.input[0] # compute new thresholds and set initializer - Tnew = T / A.reshape(-1, T.shape[1]) - # TODO: need to handle negative A values correctly; produce - # mul sign mask and merge into preceding matmul? + Tnew = T - A.reshape(-1, T.shape[1]) model.set_initializer(threshold_name, Tnew) # wire add input directly to MultiThreshold consumer.input[0] = start_name - # remove the mul node + # remove the add node graph.node.remove(n) graph_modified = True - return (model, graph_modified) + return (model, graph_modified) + + +class AbsorbMulIntoMultiThreshold(Transformation): + """Absorb preceding Mul ops into MultiThreshold by updating the threshold + values. Only *positive* scalar/1D vectors can be absorbed.""" + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Mul": + mul_weight_name = n.input[1] + A = model.get_initializer(mul_weight_name) + assert A is not None + is_signed = (A < 0).any() + is_scalar = np.prod(A.shape) == 1 + is_1d = len(A.shape) == 2 and A.shape[0] == 1 + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "MultiThreshold": + if not is_signed and (is_1d or is_scalar): + threshold_name = consumer.input[1] + T = model.get_initializer(threshold_name) + assert T is not None + start_name = n.input[0] + # compute new thresholds and set initializer + Tnew = T / A.reshape(-1, T.shape[1]) + # TODO: need to handle negative A values correctly; produce + # mul sign mask and merge into preceding matmul? + model.set_initializer(threshold_name, Tnew) + # wire add input directly to MultiThreshold + consumer.input[0] = start_name + # remove the mul node + graph.node.remove(n) + graph_modified = True + return (model, graph_modified) -def factor_out_mul_sign_magnitude(model): +class FactorOutMulSignMagnitude(Transformation): """Split multiply-by-constant nodes into two multiply-by-constant nodes, where the first node is a bipolar vector (of signs) and the second is a vector of magnitudes.""" - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "Mul": - mul_weight_name = n.input[1] - A = model.get_initializer(mul_weight_name) - assert A is not None - is_scalar = np.prod(A.shape) == 1 - is_1d = len(A.shape) == 2 and A.shape[0] == 1 - is_not_bipolar = ( - model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR - ) - is_signed = (A < 0).any() - if is_signed and (is_scalar or is_1d) and is_not_bipolar: - start_name = n.input[0] - in_shape = model.get_tensor_shape(start_name) - middle_name = model.make_new_valueinfo_name() - model.set_tensor_shape(middle_name, in_shape) - sign_mul_param_name = model.make_new_valueinfo_name() - # create new mul node with sign(A) as the operand - sgn = np.sign(A) - model.set_initializer(sign_mul_param_name, sgn) - model.set_tensor_datatype(sign_mul_param_name, DataType.BIPOLAR) - # replace original mul weight by magnitudes - model.set_initializer(mul_weight_name, np.abs(A)) - new_mul = oh.make_node( - "Mul", [start_name, sign_mul_param_name], [middle_name] + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Mul": + mul_weight_name = n.input[1] + A = model.get_initializer(mul_weight_name) + assert A is not None + is_scalar = np.prod(A.shape) == 1 + is_1d = len(A.shape) == 2 and A.shape[0] == 1 + is_not_bipolar = ( + model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR ) - n.input[0] = middle_name - graph.node.insert(node_ind - 1, new_mul) - graph_modified = True - return (model, graph_modified) + is_signed = (A < 0).any() + if is_signed and (is_scalar or is_1d) and is_not_bipolar: + start_name = n.input[0] + in_shape = model.get_tensor_shape(start_name) + middle_name = model.make_new_valueinfo_name() + model.set_tensor_shape(middle_name, in_shape) + sign_mul_param_name = model.make_new_valueinfo_name() + # create new mul node with sign(A) as the operand + sgn = np.sign(A) + model.set_initializer(sign_mul_param_name, sgn) + model.set_tensor_datatype(sign_mul_param_name, DataType.BIPOLAR) + # replace original mul weight by magnitudes + model.set_initializer(mul_weight_name, np.abs(A)) + new_mul = oh.make_node( + "Mul", [start_name, sign_mul_param_name], [middle_name] + ) + n.input[0] = middle_name + graph.node.insert(node_ind - 1, new_mul) + graph_modified = True + return (model, graph_modified) -def absorb_1bit_mul_into_matmul(model): +class Absorb1BitMulIntoMatMul(Transformation): """Absorb bipolar or binary multiplications into the preciding matrix multiply.""" - graph = model.graph - node_ind = 0 - graph_modified = False - for n in graph.node: - node_ind += 1 - if n.op_type == "MatMul": - matmul_weight_name = n.input[1] - W = model.get_initializer(matmul_weight_name) - assert W is not None - consumer = model.find_consumer(n.output[0]) - if consumer is not None and consumer.op_type == "Mul": - mul_weight_name = consumer.input[1] - A = model.get_initializer(mul_weight_name) - assert A is not None - is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1 - if is_1bit: - Wnew = A * W - assert Wnew.shape == W.shape - model.set_initializer(matmul_weight_name, Wnew) - n.output[0] = consumer.output[0] - graph.node.remove(consumer) - graph_modified = True - return (model, graph_modified) + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "MatMul": + matmul_weight_name = n.input[1] + W = model.get_initializer(matmul_weight_name) + assert W is not None + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "Mul": + mul_weight_name = consumer.input[1] + A = model.get_initializer(mul_weight_name) + assert A is not None + is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1 + if is_1bit: + Wnew = A * W + assert Wnew.shape == W.shape + model.set_initializer(matmul_weight_name, Wnew) + n.output[0] = consumer.output[0] + graph.node.remove(consumer) + graph_modified = True + return (model, graph_modified) -def round_thresholds(model): +class RoundThresholds(Transformation): """For MultiThreshold nodes operating on integer inputs, round up thresholds values to the nearest integer.""" - graph = model.graph - graph_modified = False - for n in graph.node: - if n.op_type == "MultiThreshold": - idtype = model.get_tensor_datatype(n.input[0]) - T = model.get_initializer(n.input[1]) - Tnew = np.ceil(T) - if idtype.is_integer() and (T != Tnew).any(): - # round up the thresholds to nearest integer - model.set_initializer(n.input[1], Tnew) - # use same datatype as inputs for thresholds - model.set_tensor_datatype(n.input[1], idtype) - graph_modified = True - return (model, graph_modified) + + def apply(self, model): + graph = model.graph + graph_modified = False + for n in graph.node: + if n.op_type == "MultiThreshold": + idtype = model.get_tensor_datatype(n.input[0]) + T = model.get_initializer(n.input[1]) + Tnew = np.ceil(T) + if idtype.is_integer() and (T != Tnew).any(): + # round up the thresholds to nearest integer + model.set_initializer(n.input[1], Tnew) + # use same datatype as inputs for thresholds + model.set_tensor_datatype(n.input[1], idtype) + graph_modified = True + return (model, graph_modified) diff --git a/tests/test_basic_onnx_exec.py b/tests/test_basic_onnx_exec.py index 30e9106febcbf9a84995ceb1d03f4c3782291d8d..c7b3da1b78385d36fc73790b22336242141d5255 100644 --- a/tests/test_basic_onnx_exec.py +++ b/tests/test_basic_onnx_exec.py @@ -5,15 +5,15 @@ import onnx import onnx.numpy_helper as np_helper import finn.core.onnx_exec as oxe -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes def test_mnist_onnx_download_extract_run(): # load the onnx model raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") model = ModelWrapper(raw_m) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb") diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py index bb66f98f490069ffc12c4503a5a897c37bcf93b8..dec01b53a9081c6c04143e4465d603d66fc771a6 100644 --- a/tests/test_batchnorm_to_affine.py +++ b/tests/test_batchnorm_to_affine.py @@ -8,10 +8,10 @@ import torch from models.LFC import LFC import finn.core.onnx_exec as oxe -import finn.transformation.batchnorm_to_affine as tx -import finn.transformation.fold_constants as fc -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.batchnorm_to_affine import BatchNormToAffine +from finn.transformation.fold_constants import FoldConstants +from finn.transformation.infer_shapes import InferShapes export_onnx_path = "test_output_lfc.onnx" transformed_onnx_path = "test_output_lfc_transformed.onnx" @@ -27,9 +27,9 @@ def test_batchnorm_to_affine(): lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) - model = model.transform_single(si.infer_shapes) - model = model.transform_repeated(fc.fold_constants) - new_model = model.transform_single(tx.batchnorm_to_affine) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) + new_model = model.transform(BatchNormToAffine()) # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") input_tensor = onnx.load_tensor_from_string(raw_i) diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index 80850edb72af4f37e0e626df5b26513450d0c9d3..641e5e3c49b917ad06d649d797e6f9bc9f30f170 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -9,21 +9,25 @@ import torch from models.LFC import LFC import finn.core.onnx_exec as oxe -import finn.transformation.fold_constants as fc -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.fold_constants import FoldConstants +from finn.transformation.infer_shapes import InferShapes export_onnx_path = "test_output_lfc.onnx" # TODO get from config instead, hardcoded to Docker path for now -trained_lfc_checkpoint = ( +trained_lfc_w1a1_checkpoint = ( "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" ) +trained_lfc_w1a2_checkpoint = ( + "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W2A/checkpoints/best.tar" +) + -def test_brevitas_trained_lfc_pytorch(): +def test_brevitas_trained_lfc_w1a1_pytorch(): # load pretrained weights into LFC-w1a1 lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval() - checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu") lfc.load_state_dict(checkpoint["state_dict"]) # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") @@ -49,14 +53,68 @@ def test_brevitas_trained_lfc_pytorch(): assert np.isclose(produced, expected, atol=1e-4).all() -def test_brevitas_to_onnx_export_and_exec(): +def test_brevitas_trained_lfc_w1a2_pytorch(): + # load pretrained weights into LFC-w1a2 + lfc = LFC(weight_bit_width=1, act_bit_width=2, in_bit_width=2).eval() + checkpoint = torch.load(trained_lfc_w1a2_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + # load one of the test vectors + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) + input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float() + assert input_tensor.shape == (1, 1, 28, 28) + # do forward pass in PyTorch/Brevitas + produced = lfc.forward(input_tensor).detach().numpy() + expected = [ + [ + 4.598069, + -6.3698025, + 10.75695, + 0.3796571, + 1.4764442, + -5.4417515, + -1.8982856, + -5.610488, + 6.116698, + 0.21092065, + ] + ] + assert np.isclose(produced, expected, atol=1e-4).all() + + +def test_brevitas_to_onnx_export_and_exec_lfc_w1a1(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) - checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) + model = ModelWrapper(export_onnx_path) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) + # load one of the test vectors + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) + # run using FINN-based execution + input_dict = {"0": nph.to_array(input_tensor)} + output_dict = oxe.execute_onnx(model, input_dict) + produced = output_dict[list(output_dict.keys())[0]] + # run using PyTorch/Brevitas + input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float() + assert input_tensor.shape == (1, 1, 28, 28) + # do forward pass in PyTorch/Brevitas + expected = lfc.forward(input_tensor).detach().numpy() + assert np.isclose(produced, expected, atol=1e-3).all() + # remove the downloaded model and extracted files + os.remove(export_onnx_path) + + +def test_brevitas_to_onnx_export_and_exec_lfc_w1a2(): + lfc = LFC(weight_bit_width=1, act_bit_width=2, in_bit_width=2) + checkpoint = torch.load(trained_lfc_w1a2_checkpoint, map_location="cpu") lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) - model = model.transform_single(si.infer_shapes) - model = model.transform_repeated(fc.fold_constants) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") input_tensor = onnx.load_tensor_from_string(raw_i) diff --git a/tests/test_collapse_repeated_op.py b/tests/test_collapse_repeated_op.py index d8cdde3c654873b368653347863af05317bc5bcb..d97cdbc3033000b94e499988155b3af829040109 100644 --- a/tests/test_collapse_repeated_op.py +++ b/tests/test_collapse_repeated_op.py @@ -3,9 +3,9 @@ import onnx.helper as oh from onnx import TensorProto import finn.core.onnx_exec as ox -import finn.transformation.infer_shapes as si -import finn.transformation.streamline as tx from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes +from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul def test_collapse_repeated_op(): @@ -30,12 +30,12 @@ def test_collapse_repeated_op(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) model.set_initializer("add_param_0", np.asarray([1, 3], dtype=np.float32)) model.set_initializer("add_param_1", np.asarray([-1, 3], dtype=np.float32)) model.set_initializer("mul_param_0", np.asarray([2, 4], dtype=np.float32)) model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32)) - new_model = model.transform_repeated(tx.collapse_repeated_add) - new_model = new_model.transform_repeated(tx.collapse_repeated_mul) + new_model = model.transform(CollapseRepeatedAdd()) + new_model = new_model.transform(CollapseRepeatedMul()) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) diff --git a/tests/test_factor_out_mul_sign_magnitude.py b/tests/test_factor_out_mul_sign_magnitude.py index 2b4c0ba936366486a74b4c7fc8924e4d1aa67dda..3786492fefff046b098bce06cb2e624723b098ec 100644 --- a/tests/test_factor_out_mul_sign_magnitude.py +++ b/tests/test_factor_out_mul_sign_magnitude.py @@ -3,9 +3,9 @@ import onnx.helper as oh from onnx import TensorProto import finn.core.onnx_exec as ox -import finn.transformation.infer_shapes as si -import finn.transformation.streamline as tx from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes +from finn.transformation.streamline import FactorOutMulSignMagnitude def test_factor_out_mul_sign_magnitude(): @@ -22,8 +22,8 @@ def test_factor_out_mul_sign_magnitude(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) model.set_initializer("mul_param", np.asarray([[-1, 4]], dtype=np.float32)) - new_model = model.transform_repeated(tx.factor_out_mul_sign_magnitude) + new_model = model.transform(FactorOutMulSignMagnitude()) inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) diff --git a/tests/test_fold_constants.py b/tests/test_fold_constants.py index 894887b0c91420a5755f2e737137f2fa47862b0b..09dbd95c27cec65183ec7dd6067ce187595fcf52 100644 --- a/tests/test_fold_constants.py +++ b/tests/test_fold_constants.py @@ -8,9 +8,9 @@ import onnx.numpy_helper as np_helper from models.LFC import LFC import finn.core.onnx_exec as oxe -import finn.transformation.fold_constants as fc -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.fold_constants import FoldConstants +from finn.transformation.infer_shapes import InferShapes export_onnx_path = "test_output_lfc.onnx" @@ -18,8 +18,8 @@ export_onnx_path = "test_output_lfc.onnx" def test_const_folding(): raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") model = ModelWrapper(raw_m) - model = model.transform_single(si.infer_shapes) - model = model.transform_single(fc.fold_constants) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb") input_tensor = onnx.load_tensor_from_string(raw_i) @@ -35,8 +35,8 @@ def test_const_folding_shapes(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) - model = model.transform_single(si.infer_shapes) - model = model.transform_repeated(fc.fold_constants) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) assert model.graph.node[0].op_type == "Reshape" assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28] assert list(model.get_tensor_shape("27")) == [1, 784] diff --git a/tests/test_general_transformation.py b/tests/test_general_transformation.py index d4376cca6051bd45669c5bb1a6102dc7db76e3b4..dd44502add93d269b8d48ed951620b6a36f9fb1b 100644 --- a/tests/test_general_transformation.py +++ b/tests/test_general_transformation.py @@ -1,13 +1,13 @@ from pkgutil import get_data -import finn.transformation.general as tg from finn.core.modelwrapper import ModelWrapper +from finn.transformation.general import GiveUniqueNodeNames def test_give_unique_node_names(): raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") model = ModelWrapper(raw_m) - model = model.transform_single(tg.give_unique_node_names) + model = model.transform(GiveUniqueNodeNames()) assert model.graph.node[0].name == "Reshape_0" assert model.graph.node[1].name == "Conv_0" assert model.graph.node[11].name == "Add_2" diff --git a/tests/test_infer_datatypes.py b/tests/test_infer_datatypes.py index fd1c99a1da5f5bdb93f0b8b089fddb20e3c078ed..e4269499c55eb89d9a9c268f79456ca6ac588028 100644 --- a/tests/test_infer_datatypes.py +++ b/tests/test_infer_datatypes.py @@ -4,12 +4,12 @@ import brevitas.onnx as bo import torch from models.LFC import LFC -import finn.transformation.fold_constants as fc -import finn.transformation.general as tg -import finn.transformation.infer_datatypes as id -import finn.transformation.infer_shapes as si from finn.core.datatype import DataType from finn.core.modelwrapper import ModelWrapper +from finn.transformation.fold_constants import FoldConstants +from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames +from finn.transformation.infer_datatypes import InferDataTypes +from finn.transformation.infer_shapes import InferShapes export_onnx_path = "test_output_lfc.onnx" # TODO get from config instead, hardcoded to Docker path for now @@ -24,11 +24,11 @@ def test_infer_datatypes(): lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) - model = model.transform_single(si.infer_shapes) - model = model.transform_repeated(fc.fold_constants) - model = model.transform_single(tg.give_unique_node_names) - model = model.transform_single(tg.give_readable_tensor_names) - model = model.transform_repeated(id.infer_datatypes) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + model = model.transform(InferDataTypes()) assert model.get_tensor_datatype("MatMul_0_out0") == DataType.INT32 assert model.get_tensor_datatype("MatMul_1_out0") == DataType.INT32 assert model.get_tensor_datatype("MatMul_2_out0") == DataType.INT32 diff --git a/tests/test_infer_shapes.py b/tests/test_infer_shapes.py index f4c27572cf30decce49f93fb32f12868440b00e2..20841b32275968ed842fdbbebffa7168b61b7e06 100644 --- a/tests/test_infer_shapes.py +++ b/tests/test_infer_shapes.py @@ -3,8 +3,9 @@ from pkgutil import get_data import numpy as np from onnx import TensorProto, helper -import finn.transformation.infer_shapes as si +import finn.core.utils as util from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes def test_infer_shapes(): @@ -26,7 +27,7 @@ def test_infer_shapes(): # thresholds for one channel have to be sorted to guarantee the correct behavior mt_thresh0_values = np.empty([8, 7], dtype=np.float32) for i in range(len(mt_thresh0_values)): - mt_thresh0_values[i] = np.sort(np.random.random_sample(7,) * 10) + mt_thresh0_values[i] = np.sort(np.random.random_sample(7) * 10) model.set_initializer(mt_thresh0.name, mt_thresh0_values) @@ -36,6 +37,9 @@ def test_infer_shapes(): ) Relu_node.output[0] = "mt_v0" + # explicitly remove any present shape from ReLU and MultiThreshold outputs + util.remove_by_name(model.graph.value_info, Relu_node.output[0]) + util.remove_by_name(model.graph.value_info, mt_node.output[0]) graph.node.insert(4, mt_node) # first check routine @@ -45,7 +49,7 @@ def test_infer_shapes(): ), "All tensors are already specified before the shape inference execution" # perform shape inference on mixed model - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) # second check routine # now all shapes should be specified and mt_node output shape is (1,8,28,28) diff --git a/tests/test_is_linear.py b/tests/test_is_linear.py index 1995604b7b0d1b816dea17500486c9ece1ac04c6..cce4612a93112608564b59d04c0370b195c5f2d4 100644 --- a/tests/test_is_linear.py +++ b/tests/test_is_linear.py @@ -2,8 +2,8 @@ import onnx.helper as oh from onnx import TensorProto import finn.analysis.topology as ta -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes def test_is_linear_linear(): @@ -24,7 +24,7 @@ def test_is_linear_linear(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) ret = model.analysis(ta.is_linear) assert ret["is_linear"] is True @@ -52,6 +52,6 @@ def test_is_linear_forked_node_output(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) ret = model.analysis(ta.is_linear) assert ret["is_linear"] is False diff --git a/tests/test_mixed_onnx_exec.py b/tests/test_mixed_onnx_exec.py index 0ac3f09da2f2b9fc27d88b3c90bd178adaeaf25b..75170da3ad3a8c9b24814e171b8ccdfde4fb74cd 100644 --- a/tests/test_mixed_onnx_exec.py +++ b/tests/test_mixed_onnx_exec.py @@ -2,8 +2,8 @@ import numpy as np from onnx import TensorProto, helper import finn.core.onnx_exec as oxe -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes def test_execute_mixed_model(): @@ -30,7 +30,7 @@ def test_execute_mixed_model(): model_def = helper.make_model(graph_def, producer_name="onnx-example") model = ModelWrapper(model_def) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) inputs = np.asarray( [ diff --git a/tests/test_move_add_past_mul.py b/tests/test_move_add_past_mul.py index a23827d0e91a543cf5595c816628169a251bf4ce..565bbce39b83d94352d2cd69d01d8270675bbe1f 100644 --- a/tests/test_move_add_past_mul.py +++ b/tests/test_move_add_past_mul.py @@ -3,9 +3,9 @@ import onnx.helper as oh from onnx import TensorProto import finn.core.onnx_exec as ox -import finn.transformation.infer_shapes as si -import finn.transformation.streamline as tx from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes +from finn.transformation.streamline import MoveAddPastMul def test_move_add_past_mul_single(): @@ -26,10 +26,10 @@ def test_move_add_past_mul_single(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) model.set_initializer("add_param", np.asarray([1, 3], dtype=np.float32)) model.set_initializer("mul_param", np.asarray([2, 4], dtype=np.float32)) - new_model = model.transform_repeated(tx.move_add_past_mul) + new_model = model.transform(MoveAddPastMul()) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) @@ -56,11 +56,11 @@ def test_move_add_past_mul_multi(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) model.set_initializer("add_param_0", np.asarray([1, 3], dtype=np.float32)) model.set_initializer("mul_param_0", np.asarray([2, 4], dtype=np.float32)) model.set_initializer("add_param_1", np.asarray([-1, 3], dtype=np.float32)) model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32)) - new_model = model.transform_repeated(tx.move_add_past_mul) + new_model = model.transform(MoveAddPastMul()) inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) diff --git a/tests/test_move_scalar_past_matmul.py b/tests/test_move_scalar_past_matmul.py index 7bbdd7dd8506437371147faa98310b8caa115318..c2771ce94d002ab62d33226771965140f8614ec1 100644 --- a/tests/test_move_scalar_past_matmul.py +++ b/tests/test_move_scalar_past_matmul.py @@ -3,9 +3,12 @@ import onnx.helper as oh from onnx import TensorProto import finn.core.onnx_exec as ox -import finn.transformation.infer_shapes as si -import finn.transformation.streamline as tx from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes +from finn.transformation.streamline import ( + MoveScalarAddPastMatMul, + MoveScalarMulPastMatMul +) def test_move_scalar_mul_past_matmul(): @@ -26,12 +29,12 @@ def test_move_scalar_mul_past_matmul(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) model.set_initializer("mul_param", np.asarray([[3]], dtype=np.float32)) model.set_initializer( "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32) ) - new_model = model.transform_repeated(tx.move_scalar_mul_past_matmul) + new_model = model.transform(MoveScalarMulPastMatMul()) inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) assert new_model.graph.node[0].op_type == "MatMul" @@ -57,12 +60,12 @@ def test_move_scalar_add_past_matmul(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) model.set_initializer("add_param", np.asarray([[3]], dtype=np.float32)) model.set_initializer( "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32) ) - new_model = model.transform_repeated(tx.move_scalar_add_past_matmul) + new_model = model.transform(MoveScalarAddPastMatMul()) inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)} assert ox.compare_execution(model, new_model, inp_dict) assert new_model.graph.node[0].op_type == "MatMul" diff --git a/tests/test_renaming.py b/tests/test_renaming.py index c13f7ee066e514064a2943493d810f48aa8f97be..aec1c8a10768f293ff9aaf44d7418c777837010c 100644 --- a/tests/test_renaming.py +++ b/tests/test_renaming.py @@ -5,18 +5,18 @@ import onnx import onnx.numpy_helper as np_helper import finn.core.onnx_exec as oxe -import finn.transformation.general as tg -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames +from finn.transformation.infer_shapes import InferShapes def test_renaming(): # load the onnx model raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") model = ModelWrapper(raw_m) - model = model.transform_single(si.infer_shapes) - model = model.transform_single(tg.give_unique_node_names) - model = model.transform_single(tg.give_readable_tensor_names) + model = model.transform(InferShapes()) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) # do some basic checks assert model.graph.input[0].name == "global_in" assert model.graph.output[0].name == "global_out" @@ -27,8 +27,8 @@ def test_renaming(): assert model.graph.node[6].name == "Add_1" assert model.graph.node[6].input[1] == "Add_1_param0" # ensure running renaming twice still yields the same names - model = model.transform_single(tg.give_unique_node_names) - model = model.transform_single(tg.give_readable_tensor_names) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) assert model.graph.node[1].op_type == "Conv" assert model.graph.node[1].name == "Conv_0" assert model.graph.node[1].input[1] == "Conv_0_param0" diff --git a/tests/test_round_thresholds.py b/tests/test_round_thresholds.py index 28add7313201f1e30ecaf98d841f762a3f859835..afd1d86f868ed6b1343c3f6a1265a9016c3f8cef 100644 --- a/tests/test_round_thresholds.py +++ b/tests/test_round_thresholds.py @@ -2,9 +2,9 @@ import numpy as np from onnx import TensorProto, helper import finn.core.onnx_exec as oxe -import finn.transformation.streamline as sl from finn.core.datatype import DataType from finn.core.modelwrapper import ModelWrapper +from finn.transformation.streamline import RoundThresholds def test_round_thresholds(): @@ -27,7 +27,7 @@ def test_round_thresholds(): orig_n = oxe.execute_onnx(model, inp_dict_n)["out"] orig_c = oxe.execute_onnx(model, inp_dict_c)["out"] assert model.get_tensor_datatype("thresholds") == DataType.FLOAT32 - new_model = model.transform_repeated(sl.round_thresholds) + new_model = model.transform(RoundThresholds()) # rounded up thresholds should have same dtype as input assert new_model.get_tensor_datatype("thresholds") == DataType.INT8 new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"] diff --git a/tests/test_sign_to_thres.py b/tests/test_sign_to_thres.py index 4e18822cf9beceab009678876c1ca20f581ec22a..75327df3e5194f48f71c914bc62fc5a08588faff 100644 --- a/tests/test_sign_to_thres.py +++ b/tests/test_sign_to_thres.py @@ -8,10 +8,10 @@ import torch from models.LFC import LFC import finn.core.onnx_exec as oxe -import finn.transformation.fold_constants as fc -import finn.transformation.infer_shapes as si -import finn.transformation.streamline as sl from finn.core.modelwrapper import ModelWrapper +from finn.transformation.fold_constants import FoldConstants +from finn.transformation.infer_shapes import InferShapes +from finn.transformation.streamline import ConvertSignToThres export_onnx_path = "test_output_lfc.onnx" transformed_onnx_path = "test_output_lfc_transformed.onnx" @@ -27,9 +27,9 @@ def test_sign_to_thres(): lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) - model = model.transform_single(si.infer_shapes) - model = model.transform_repeated(fc.fold_constants) - new_model = model.transform_single(sl.convert_sign_to_thres) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) + new_model = model.transform(ConvertSignToThres()) assert new_model.graph.node[3].op_type == "MultiThreshold" # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") diff --git a/tests/test_streamline.py b/tests/test_streamline.py index 1a93c6ff128b2169da4802d7367608f499e89beb..75b0301072bd603e0c25357fa022b66d1df4e467 100644 --- a/tests/test_streamline.py +++ b/tests/test_streamline.py @@ -9,13 +9,17 @@ import torch from models.LFC import LFC import finn.core.onnx_exec as oxe -import finn.transformation.batchnorm_to_affine as ba -import finn.transformation.fold_constants as fc -import finn.transformation.general as tg -import finn.transformation.infer_datatypes as di -import finn.transformation.infer_shapes as si import finn.transformation.streamline as sl from finn.core.modelwrapper import ModelWrapper +from finn.transformation.batchnorm_to_affine import BatchNormToAffine +from finn.transformation.fold_constants import FoldConstants +from finn.transformation.general import ( + ConvertSubToAdd, + GiveReadableTensorNames, + GiveUniqueNodeNames +) +from finn.transformation.infer_datatypes import InferDataTypes +from finn.transformation.infer_shapes import InferShapes export_onnx_path = "test_output_lfc.onnx" # TODO get from config instead, hardcoded to Docker path for now @@ -30,10 +34,10 @@ def test_streamline_lfc_w1a1(): lfc.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) - model = model.transform_single(si.infer_shapes) - model = model.transform_repeated(fc.fold_constants) - model = model.transform_single(tg.give_unique_node_names) - model = model.transform_single(tg.give_readable_tensor_names) + model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") input_tensor = onnx.load_tensor_from_string(raw_i) @@ -42,26 +46,26 @@ def test_streamline_lfc_w1a1(): expected_ctx = oxe.execute_onnx(model, input_dict, True) expected = expected_ctx[model.graph.output[0].name] transforms = [ - tg.convert_sub_to_add, - ba.batchnorm_to_affine, - sl.convert_sign_to_thres, - sl.move_scalar_add_past_matmul, - sl.move_scalar_mul_past_matmul, - sl.move_add_past_mul, - sl.collapse_repeated_add, - sl.collapse_repeated_mul, - sl.absorb_add_into_multi_threshold, - sl.factor_out_mul_sign_magnitude, - sl.absorb_mul_into_multi_threshold, - sl.absorb_1bit_mul_into_matmul, - sl.round_thresholds, + ConvertSubToAdd(), + BatchNormToAffine(), + sl.ConvertSignToThres(), + sl.MoveScalarAddPastMatMul(), + sl.MoveScalarMulPastMatMul(), + sl.MoveAddPastMul(), + sl.CollapseRepeatedAdd(), + sl.CollapseRepeatedMul(), + sl.AbsorbAddIntoMultiThreshold(), + sl.FactorOutMulSignMagnitude(), + sl.AbsorbMulIntoMultiThreshold(), + sl.Absorb1BitMulIntoMatMul(), + sl.RoundThresholds(), ] trn_ind = 0 for trn in transforms: - model = model.transform_repeated(trn) - model = model.transform_single(tg.give_unique_node_names) - model = model.transform_single(tg.give_readable_tensor_names) - model = model.transform_repeated(di.infer_datatypes) + model = model.transform(trn) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + model = model.transform(InferDataTypes()) produced_ctx = oxe.execute_onnx(model, input_dict, True) produced = produced_ctx[model.graph.output[0].name] # model.save("%d-%s.onnx" % (trn_ind, trn.__name__)) diff --git a/tests/test_topology_checks.py b/tests/test_topology_checks.py index 2e9582a8dc23fe8ce5b3e409f19e8db7a332e9f7..e28ceac09ecfdd242285f6b9e355bd4c2cfd7e68 100644 --- a/tests/test_topology_checks.py +++ b/tests/test_topology_checks.py @@ -4,8 +4,8 @@ import onnx.helper as oh from onnx import TensorProto import finn.analysis.topology as ta -import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper +from finn.transformation.infer_shapes import InferShapes def test_all_tensors_f32(): @@ -26,7 +26,7 @@ def test_all_tensors_f32(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) ret = model.analysis(ta.all_tensors_f32) assert ret["all_tensors_f32"] is True @@ -47,7 +47,7 @@ def test_all_tensors_f32(): ) ) model = ModelWrapper(modelproto) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) ret = model.analysis(ta.all_tensors_f32) assert ret["all_tensors_f32"] is False @@ -55,7 +55,7 @@ def test_all_tensors_f32(): def test_node_inputs_in_expected_order(): raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") model = ModelWrapper(raw_m) - model = model.transform_single(si.infer_shapes) + model = model.transform(InferShapes()) ret = model.analysis(ta.node_inputs_in_expected_order) # this model has an (unnecessary) dynamic reshape for its weight tensor # and so it fails the check