diff --git a/src/finn/analysis/__init__.py b/src/finn/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18e1efb37e9be82de26d933cccaf64a85bc8ff22 --- /dev/null +++ b/src/finn/analysis/__init__.py @@ -0,0 +1,8 @@ +""" +How to write an analysis pass for FINN +-------------------------------------- + +An analysis pass traverses the graph structure and produces information about +certain properties. The convention is to take in a ModelWrapper, and return +a dictionary of named properties that the analysis extracts. +""" diff --git a/src/finn/analysis/topology.py b/src/finn/analysis/topology.py new file mode 100644 index 0000000000000000000000000000000000000000..9150e2b118cc9e17464e5c9866f83005576d40df --- /dev/null +++ b/src/finn/analysis/topology.py @@ -0,0 +1,47 @@ +import numpy as np + + +def is_linear(model): + """Checks whether the given model graph is linear. This is done by looking + at the fan-out of each tensor. All tensors have a fan-out <= 1 in a linear + graph. Returns {"is_linear": Bool}.""" + per_tensor_fanouts = get_per_tensor_fanouts(model) + # check for tensors that have fanout > 1 + multi_fanouts = list(filter(lambda x: x[1] > 1, per_tensor_fanouts.items())) + return {"is_linear": len(multi_fanouts) == 0} + + +def get_per_tensor_fanouts(model): + """Returns a dictionary of (tensor_name, tensor_fanout) for the model.""" + # make execution context to get a list of tensors + per_tensor_fanouts = model.make_empty_exec_context() + # replace every tensor with its fanout + for tensor_name in per_tensor_fanouts.keys(): + per_tensor_fanouts[tensor_name] = model.get_tensor_fanout(tensor_name) + return per_tensor_fanouts + + +def all_tensors_f32(model): + """Checks whether all tensors have a float32 dtype, regardless of any + extra quantization annotations.""" + all_tensors = model.make_empty_exec_context().items() + non_f32_tensors = filter(lambda x: x[1].dtype != np.float32, all_tensors) + return {"all_tensors_f32": len(list(non_f32_tensors)) == 0} + + +def node_inputs_in_expected_order(model): + """Verifies that the node inputs are ordered in the way that FINN expects + them. When a node has a mixture of static (= constant, initialized) inputs + and dynamic inputs, the dynamic input should come first, followed by the + static one. Only verifiable for a small subset of op_types for now.""" + op_types = ["MatMul", "Conv", "Add", "Mul"] + nodes = filter(lambda x: x.op_type in op_types, model.graph.node) + all_OK = True + for n in nodes: + all_OK = all_OK and len(list(n.input)) == 2 + # input 0 should be dynamic, no initializer + all_OK = all_OK and (model.get_initializer(n.input[0]) is None) + # input 1 should be static (unless eltwise add) + if n.op_type != "Add": + all_OK = all_OK and (model.get_initializer(n.input[1]) is not None) + return {"node_inputs_in_expected_order": all_OK} diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py index 4b892694cdf3138e6c6e804d5c6ac044833896f4..b8be06d4b55875937d434cb941efc0891b1655c6 100644 --- a/src/finn/core/modelwrapper.py +++ b/src/finn/core/modelwrapper.py @@ -1,7 +1,9 @@ import copy import onnx +import onnx.helper as oh import onnx.numpy_helper as np_helper +from onnx import TensorProto import finn.core.utils as util @@ -13,12 +15,14 @@ class ModelWrapper: def __init__(self, onnx_model_proto, make_deepcopy=False): """Creates a ModelWrapper instance. onnx_model_proto can be either a ModelProto instance, or a string - with the path to a stored .onnx file on disk.
+ with the path to a stored .onnx file on disk, or serialized bytes. The make_deepcopy option controls whether a deep copy of the ModelProto is made internally. """ if isinstance(onnx_model_proto, str): self._model_proto = onnx.load(onnx_model_proto) + elif isinstance(onnx_model_proto, bytes): + self._model_proto = onnx.load_from_string(onnx_model_proto) else: if make_deepcopy: self._model_proto = copy.deepcopy(onnx_model_proto) @@ -45,6 +49,10 @@ class ModelWrapper: """Save the wrapper ONNX ModelProto into a file with given name.""" onnx.save(self._model_proto, filename) + def analysis(self, analysis_fxn): + """Run given analysis_fxn on this model and return resulting dict.""" + return analysis_fxn(self) + def transform_repeated(self, transform, make_deepcopy=True): """Applies given transform repeatedly until no more changes can be made and returns a transformed ModelWrapper instance. @@ -94,6 +102,21 @@ class ModelWrapper: except ValueError: return None + def set_tensor_shape(self, tensor_name, tensor_shape): + """Assign shape in ValueInfoProto for tensor with given name.""" + dtype = TensorProto.FLOAT + new_vi = oh.make_tensor_value_info(tensor_name, dtype, tensor_shape) + # find which container this tensor's ValueInfo lives in + # if not found anywhere, we assume it's a new value_info + target_container = self.graph.value_info + if util.get_by_name(self.graph.input, tensor_name) is not None: + target_container = self.graph.input + if util.get_by_name(self.graph.output, tensor_name) is not None: + target_container = self.graph.output + # remove from target container and add new + util.remove_by_name(target_container, tensor_name) + target_container.append(new_vi) + def set_initializer(self, tensor_name, tensor_value): """Set the initializer value for tensor with given name.""" graph = self._model_proto.graph @@ -110,6 +133,35 @@ class ModelWrapper: pass # create and insert new initializer graph.initializer.append(tensor_init_proto) + # set shape + self.set_tensor_shape(tensor_name, list(tensor_value.shape)) + + def rename_tensor(self, old_name, new_name): + """Rename a tensor from old_name to new_name.""" + graph = self.graph + # sweep over inputs + if util.get_by_name(graph.input, old_name) is not None: + util.get_by_name(graph.input, old_name).name = new_name + # sweep over outputs + if util.get_by_name(graph.output, old_name) is not None: + util.get_by_name(graph.output, old_name).name = new_name + # sweep over value_info + if util.get_by_name(graph.value_info, old_name) is not None: + util.get_by_name(graph.value_info, old_name).name = new_name + # sweep over quantization annotations + if ( + util.get_by_name(graph.quantization_annotation, old_name, "tensor_name") + is not None + ): + util.get_by_name( + graph.quantization_annotation, old_name, "tensor_name" + ).tensor_name = new_name + # sweep over node i/o + for n in graph.node: + if old_name in n.input: + n.input[list(n.input).index(old_name)] = new_name + if old_name in n.output: + n.output[list(n.output).index(old_name)] = new_name def get_initializer(self, tensor_name): """Get the initializer value for tensor with given name, if any.""" @@ -187,3 +239,13 @@ class ModelWrapper: for o in n.output: ret = ret and (self.get_tensor_shape(o) is not None) return ret + + def get_tensor_fanout(self, tensor_name): + """Return the number of nodes for which the tensor with given name is + used as input.""" + graph = self.graph + fanout = 0 + for n in graph.node: + if tensor_name in n.input: + fanout += 1 + return fanout diff --git 
a/src/finn/core/utils.py b/src/finn/core/utils.py index b49c3b223fe038967436a8d5db35ccb21f6750f6..d10da1ed362528a754879931f9bf3ef373bd9832 100644 --- a/src/finn/core/utils.py +++ b/src/finn/core/utils.py @@ -9,3 +9,20 @@ def valueinfo_to_tensor(vi): return np.zeros( dims, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type] ) + + +def get_by_name(container, name, name_field="name"): + """Return item from container by .name field if it exists, None otherwise.""" + names = [getattr(x, name_field) for x in container] + try: + ind = names.index(name) + return container[ind] + except ValueError: + return None + + +def remove_by_name(container, name, name_field="name"): + """Remove item from container by .name field if it exists.""" + item = get_by_name(container, name, name_field) + if item is not None: + container.remove(item) diff --git a/src/finn/data/onnx/mnist-conv/model.onnx b/src/finn/data/onnx/mnist-conv/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fc1a3f733c6e6243dd23dacb125b7a372de55a50 Binary files /dev/null and b/src/finn/data/onnx/mnist-conv/model.onnx differ diff --git a/src/finn/data/onnx/mnist-conv/test_data_set_0/input_0.pb b/src/finn/data/onnx/mnist-conv/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..f0072d51a480af615e92312608f75993be9f1136 Binary files /dev/null and b/src/finn/data/onnx/mnist-conv/test_data_set_0/input_0.pb differ diff --git a/src/finn/data/onnx/mnist-conv/test_data_set_0/output_0.pb b/src/finn/data/onnx/mnist-conv/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..a6f4cdf92e27aaab21e44098b5b28e16048098e2 Binary files /dev/null and b/src/finn/data/onnx/mnist-conv/test_data_set_0/output_0.pb differ diff --git a/src/finn/transformation/__init__.py b/src/finn/transformation/__init__.py index 385e1bc6c02ca0df972694371233ed82a8211686..727b8d32a60793c9f32df2df5120b069fb6528dc 100644 --- a/src/finn/transformation/__init__.py +++ b/src/finn/transformation/__init__.py @@ -2,10 +2,11 @@ Guide to writing FINN transformations ------------------------------------- -* Your transformation should take in an ONNX model, and return a tuple with - (transformed_model: ModelProto, model_was_changed: Bool) -* The original model should not be modified, use e.g. copy.deepcopy() if you - want to work on a copy of the graph for modifications. +* Your transformation should take in a ModelWrapper, and return a tuple with + (transformed_model: ModelWrapper, model_was_changed: Bool) +* The transformations are meant to be applied using the .transform functions + in ModelWrapper. This makes a deep copy of the input model by default, so + you don't have to. * model_was_changed indicates whether your transformation made any changes to the model. If you know your transformation needs to be called only once and repeated calls have no further effect, you can return False even if the model @@ -13,5 +14,6 @@ Guide to writing FINN transformations * You MUST return model_was_changed=False at some point when your transformation is called multiple times, otherwise apply_repeated() will loop infinitely. * If you cannot guarantee that the transformation will reach a fixed point, - you must declare this and return only the transformed model instead of a tuple.
+ you must declare this and notify the user to use .transform_single() instead + of .transform_repeated(). """ diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py index 646c69ef066280554b74b2f56afae76a9eab4326..446e27ba50637fc0ef961f23b26f85bc8235b4fc 100644 --- a/src/finn/transformation/batchnorm_to_affine.py +++ b/src/finn/transformation/batchnorm_to_affine.py @@ -8,7 +8,6 @@ import finn.transformation.infer_shapes as si def batchnorm_to_affine(model): """Replaces any test-time BatchNorm layers with Mul-Add layers.""" graph = model.graph - nodes_to_remove = [] node_ind = 0 graph_modified = False for n in graph.node: @@ -27,18 +26,15 @@ def batchnorm_to_affine(model): # TODO is a division by moving avg factor needed for variance? A = scale / np.sqrt(epsilon + variance) B = bias - (A * mean) - nodes_to_remove += [n] # see if we have surrounding Unsqueeze/Squeeze nodes we can remove producer = model.find_producer(bn_input) if producer is not None: if producer.op_type == "Unsqueeze": bn_input = producer.input[0] - nodes_to_remove += [producer] consumer = model.find_consumer(bn_output) if consumer is not None: if consumer.op_type == "Squeeze": bn_output = consumer.output[0] - nodes_to_remove += [consumer] data_shape = model.get_tensor_shape(bn_input) # create value_info and initializers for Mul and Add constants mul_const = oh.make_tensor_value_info( @@ -65,9 +61,11 @@ def batchnorm_to_affine(model): # insert where the batchnorm is to preserve topological ordering graph.node.insert(node_ind, mul_node) graph.node.insert(node_ind + 1, add_node) - # delete marked nodes (batchnorm and (un)squeezing) - for n in nodes_to_remove: - graph.node.remove(n) - graph_modified = True + # remove old nodes + graph.node.remove(n) + if consumer is not None: + graph.node.remove(consumer) + if producer is not None: + graph.node.remove(producer) model = model.transform_single(si.infer_shapes) return (model, graph_modified) diff --git a/src/finn/transformation/fold_constants.py b/src/finn/transformation/fold_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a951f057ee6b123497c19b3ef2ab020cfd7e7b03 --- /dev/null +++ b/src/finn/transformation/fold_constants.py @@ -0,0 +1,26 @@ +import finn.core.onnx_exec as oxe + + +def fold_constants(model): + """Replace any node whose inputs are all constant with an initializer + holding its precomputed result.""" + graph = model.graph + node_ind = 0 + graph_modified = False + execution_context = model.make_empty_exec_context() + for n in graph.node: + node_ind += 1 + node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input)) + node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits)) + node_out = n.output[0] + if len(node_inp_dyn) == 0: + # this node has no dynamic inputs, only constant ones -- so we can + # do constant folding. + oxe.execute_node(n, execution_context, graph) + # use the execution result as an initializer + model.set_initializer(node_out, execution_context[node_out]) + # remove old node + graph.node.remove(n) + graph_modified = True + # TODO remove unused tensors? 
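+ # note: removing a node from graph.node while iterating over it can + # skip the node that follows a removed one; re-applying this pass + # (e.g. via transform_repeated) folds anything missed on this sweep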
+ return (model, graph_modified) diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py index 55bf899cf54268254d8cecdadae9970711f92077..b8894ed24305f9aa07f6ce7b17ff7b5426ff9451 100644 --- a/src/finn/transformation/general.py +++ b/src/finn/transformation/general.py @@ -1,12 +1,26 @@ -import copy - - def give_unique_node_names(model): """Give unique names to each node in the graph using enumeration.""" - new_model = copy.deepcopy(model) node_count = 0 - for n in new_model.graph.node: + for n in model.graph.node: n.name = "%s_%d" % (n.op_type, node_count) node_count += 1 # return model_was_changed = False as single iteration is always enough - return (new_model, False) + return (model, False) + + +def give_readable_tensor_names(model): + """Give more human-readable names to all internal tensors. It's recommended + to apply give_unique_node_names prior to this transform.""" + graph = model.graph + for n in graph.node: + out_num = 0 + for o in n.output: + model.rename_tensor(o, "%s_out%d" % (n.name, out_num)) + out_num += 1 + init_in_num = 0 + for i in n.input: + if model.get_initializer(i) is not None: + model.rename_tensor(i, "%s_param%d" % (n.name, init_in_num)) + init_in_num += 1 + # return model_was_changed = False as single iteration is always enough + return (model, False) diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py new file mode 100644 index 0000000000000000000000000000000000000000..f54b78f0448711fd8dd531936f43e185bb0766ca --- /dev/null +++ b/src/finn/transformation/streamline.py @@ -0,0 +1,176 @@ +import numpy as np +from onnx import helper as oh + +import finn.transformation.infer_shapes as si + + +def collapse_repeated_op(model, op_name, make_collapsed_param_fxn): + """Collapse repeated consecutive operations with constant parameters into + a single operation. make_collapsed_param_fxn must take two tensors and + return a tensor which gives the equivalent result using a single op. """ + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == op_name: + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == op_name: + op0_param_name = n.input[1] + op1_param_name = consumer.input[1] + op0_param = model.get_initializer(op0_param_name) + op1_param = model.get_initializer(op1_param_name) + assert op0_param is not None + assert op1_param is not None + start_name = n.input[0] + end_name = consumer.output[0] + # compute the new parameter + new_param = make_collapsed_param_fxn(op0_param, op1_param) + # make and insert new node + new_node_param_name = op0_param_name + new_node = oh.make_node( + op_name, [start_name, new_node_param_name], [end_name] + ) + graph.node.insert(node_ind, new_node) + # replace parameter value + model.set_initializer(new_node_param_name, new_param) + # remove old nodes + graph.node.remove(n) + graph.node.remove(consumer) + graph_modified = True + model = model.transform_single(si.infer_shapes) + return (model, graph_modified) + + +def collapse_repeated_add(model): + return collapse_repeated_op(model, "Add", lambda x, y: y + x) + + +def collapse_repeated_mul(model): + return collapse_repeated_op(model, "Mul", lambda x, y: y * x) + + +def move_add_past_mul(model): + """Move add operations past multiply operations. 
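This rewrites + (x+B)*A into x*A + B*A, rescaling the add constant by the mul constant.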
The aim is to have them + next to each other such that they can be collapsed into a single add.""" + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Add": + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "Mul": + # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA) + # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA) + # assume input 0 is from the previous layer, input 1 is the + # trained (constant) parameter + mul_weight_name = consumer.input[1] + add_weight_name = n.input[1] + A = model.get_initializer(mul_weight_name) + B = model.get_initializer(add_weight_name) + assert A is not None + assert B is not None + start_name = n.input[0] + middle_name = n.output[0] + end_name = consumer.output[0] + # compute new param value for add + BA = B * A + # make and insert new nodes + new_mul = oh.make_node( + "Mul", [start_name, mul_weight_name], [middle_name] + ) + new_add = oh.make_node( + "Add", [middle_name, add_weight_name], [end_name] + ) + graph.node.insert(node_ind, new_mul) + graph.node.insert(node_ind + 1, new_add) + # replace add value + model.set_initializer(add_weight_name, BA) + # remove old nodes + graph.node.remove(n) + graph.node.remove(consumer) + graph_modified = True + model = model.transform_single(si.infer_shapes) + return (model, graph_modified) + + +def move_scalar_mul_past_matmul(model): + """Move scalar mul operations past matmul operations. We want to have muls + next to each other such that they can be collapsed into a single mul.""" + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Mul": + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "MatMul": + mul_weight_name = n.input[1] + matmul_weight_name = consumer.input[1] + A = model.get_initializer(mul_weight_name) + W = model.get_initializer(matmul_weight_name) + assert A is not None + assert W is not None + start_name = n.input[0] + middle_name = n.output[0] + end_name = consumer.output[0] + if all(x == 1 for x in A.shape): + # if the mul is scalar, we can simply swap the order of ops + # make and insert new nodes + new_matmul = oh.make_node( + "MatMul", [start_name, matmul_weight_name], [middle_name] + ) + new_mul = oh.make_node( + "Mul", [middle_name, mul_weight_name], [end_name] + ) + graph.node.insert(node_ind, new_matmul) + graph.node.insert(node_ind + 1, new_mul) + # remove old nodes + graph.node.remove(n) + graph.node.remove(consumer) + graph_modified = True + model = model.transform_single(si.infer_shapes) + return (model, graph_modified) + + +def move_scalar_add_past_matmul(model): + """Move scalar add operations past matmul operations. 
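This rewrites + (x+A)*W into x*W + A*W, transforming the add constant by the matmul weights.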
We want to have adds + next to each other such that they can be collapsed into a single add.""" + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "Add": + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "MatMul": + add_weight_name = n.input[1] + matmul_weight_name = consumer.input[1] + A = model.get_initializer(add_weight_name) + W = model.get_initializer(matmul_weight_name) + assert A is not None + assert W is not None + start_name = n.input[0] + middle_name = n.output[0] + end_name = consumer.output[0] + if all(x == 1 for x in A.shape): + # if the add is scalar, we can move it past the matmul + # by updating the add constant to A dot W + Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W) + # update the add weight + model.set_initializer(add_weight_name, Anew) + new_matmul = oh.make_node( + "MatMul", [start_name, matmul_weight_name], [middle_name] + ) + new_add = oh.make_node( + "Add", [middle_name, add_weight_name], [end_name] + ) + graph.node.insert(node_ind, new_matmul) + graph.node.insert(node_ind + 1, new_add) + # remove old nodes + graph.node.remove(n) + graph.node.remove(consumer) + graph_modified = True + model = model.transform_single(si.infer_shapes) + return (model, graph_modified) diff --git a/tests/test_basic_onnx_exec.py b/tests/test_basic_onnx_exec.py index 8f2a0efb05f3529c06b99f819b52689d255473ff..30e9106febcbf9a84995ceb1d03f4c3782291d8d 100644 --- a/tests/test_basic_onnx_exec.py +++ b/tests/test_basic_onnx_exec.py @@ -1,46 +1,27 @@ -import hashlib -import os -import shutil +from pkgutil import get_data import numpy as np import onnx import onnx.numpy_helper as np_helper -import wget import finn.core.onnx_exec as oxe import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper -mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist" -mnist_onnx_filename = "mnist.tar.gz" -mnist_onnx_local_dir = "/tmp/mnist_onnx" def test_mnist_onnx_download_extract_run(): - try: - os.remove("/tmp/" + mnist_onnx_filename) - except OSError: - pass - dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp") - shutil.unpack_archive(dl_ret, mnist_onnx_local_dir) - with open(mnist_onnx_local_dir + "/mnist/model.onnx", "rb") as f: - assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d" # load the onnx model - model = ModelWrapper(mnist_onnx_local_dir + "/mnist/model.onnx") + raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") + model = ModelWrapper(raw_m) model = model.transform_single(si.infer_shapes) # load one of the test vectors - input_tensor = onnx.TensorProto() - output_tensor = onnx.TensorProto() - with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f: - input_tensor.ParseFromString(f.read()) - with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/output_0.pb", "rb") as f: - output_tensor.ParseFromString(f.read()) + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) + output_tensor = onnx.load_tensor_from_string(raw_o) # run using FINN-based execution input_dict = {"Input3": np_helper.to_array(input_tensor)} output_dict = oxe.execute_onnx(model, input_dict) assert np.isclose( np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3 
).all() - # remove the downloaded model and extracted files - os.remove(dl_ret) - shutil.rmtree(mnist_onnx_local_dir) diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py index bed7ad8bb400cac5851f9c4e671d5f660745dc4a..5f0e64d3dd53bba05ca29459fe8cc075d0cb1752 100644 --- a/tests/test_batchnorm_to_affine.py +++ b/tests/test_batchnorm_to_affine.py @@ -1,14 +1,13 @@ import os -import shutil from functools import reduce from operator import mul +from pkgutil import get_data import brevitas.onnx as bo import numpy as np import onnx import onnx.numpy_helper as nph import torch -import wget from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op from torch.nn import BatchNorm1d, Dropout, Module, ModuleList @@ -23,9 +22,6 @@ LAST_FC_PER_OUT_CH_SCALING = False IN_DROPOUT = 0.2 HIDDEN_DROPOUT = 0.2 -mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist" -mnist_onnx_filename = "mnist.tar.gz" -mnist_onnx_local_dir = "/tmp/mnist_onnx" export_onnx_path = "test_output_lfc.onnx" transformed_onnx_path = "test_output_lfc_transformed.onnx" # TODO get from config instead, hardcoded to Docker path for now @@ -98,21 +94,11 @@ def test_batchnorm_to_affine(): model = ModelWrapper(export_onnx_path) model = model.transform_single(si.infer_shapes) new_model = model.transform_single(tx.batchnorm_to_affine) - try: - os.remove("/tmp/" + mnist_onnx_filename) - except OSError: - pass - dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp") - shutil.unpack_archive(dl_ret, mnist_onnx_local_dir) # load one of the test vectors - input_tensor = onnx.TensorProto() - with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f: - input_tensor.ParseFromString(f.read()) + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) input_dict = {"0": nph.to_array(input_tensor)} output_original = oxe.execute_onnx(model, input_dict)["53"] output_transformed = oxe.execute_onnx(new_model, input_dict)["53"] assert np.isclose(output_transformed, output_original, atol=1e-3).all() - # remove the downloaded model and extracted files - os.remove(dl_ret) - shutil.rmtree(mnist_onnx_local_dir) os.remove(export_onnx_path) diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py index 805557e6b0c50ccbf936c1c7ecef9980dc1b614d..83babecf501788ee1600b49bc3ef461712877517 100644 --- a/tests/test_brevitas_export.py +++ b/tests/test_brevitas_export.py @@ -1,14 +1,13 @@ import os -import shutil from functools import reduce from operator import mul +from pkgutil import get_data import brevitas.onnx as bo import numpy as np import onnx import onnx.numpy_helper as nph import torch -import wget from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op from torch.nn import BatchNorm1d, Dropout, Module, ModuleList @@ -22,9 +21,6 @@ LAST_FC_PER_OUT_CH_SCALING = False IN_DROPOUT = 0.2 HIDDEN_DROPOUT = 0.2 -mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist" -mnist_onnx_filename = "mnist.tar.gz" -mnist_onnx_local_dir = "/tmp/mnist_onnx" export_onnx_path = "test_output_lfc.onnx" # TODO get from config instead, hardcoded to Docker path for now trained_lfc_checkpoint = ( @@ -124,17 +120,9 @@ def test_brevitas_trained_lfc_pytorch(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval() checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") 
lfc.load_state_dict(checkpoint["state_dict"]) - # download some MNIST test data - try: - os.remove("/tmp/" + mnist_onnx_filename) - except OSError: - pass - dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp") - shutil.unpack_archive(dl_ret, mnist_onnx_local_dir) # load one of the test vectors - input_tensor = onnx.TensorProto() - with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f: - input_tensor.ParseFromString(f.read()) + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float() assert input_tensor.shape == (1, 1, 28, 28) # do forward pass in PyTorch/Brevitas @@ -154,9 +142,6 @@ ] ] assert np.isclose(produced, expected, atol=1e-4).all() - # remove the downloaded model and extracted files - os.remove(dl_ret) - shutil.rmtree(mnist_onnx_local_dir) def test_brevitas_to_onnx_export_and_exec(): @@ -166,16 +151,9 @@ bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform_single(si.infer_shapes) - try: - os.remove("/tmp/" + mnist_onnx_filename) - except OSError: - pass - dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp") - shutil.unpack_archive(dl_ret, mnist_onnx_local_dir) # load one of the test vectors - input_tensor = onnx.TensorProto() - with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f: - input_tensor.ParseFromString(f.read()) + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) # run using FINN-based execution input_dict = {"0": nph.to_array(input_tensor)} output_dict = oxe.execute_onnx(model, input_dict) @@ -187,6 +165,3 @@ expected = lfc.forward(input_tensor).detach().numpy() assert np.isclose(produced, expected, atol=1e-3).all() - # remove the downloaded model and extracted files - os.remove(dl_ret) - shutil.rmtree(mnist_onnx_local_dir) os.remove(export_onnx_path) diff --git a/tests/test_collapse_repeated_op.py b/tests/test_collapse_repeated_op.py new file mode 100644 index 0000000000000000000000000000000000000000..224df9c3b37de85278f48480460ed186934e3487 --- /dev/null +++ b/tests/test_collapse_repeated_op.py @@ -0,0 +1,43 @@ +import numpy as np +import onnx.helper as oh +from onnx import TensorProto + +import finn.core.onnx_exec as ox +import finn.transformation.infer_shapes as si +import finn.transformation.streamline as tx +from finn.core.modelwrapper import ModelWrapper + + +def test_collapse_repeated_op(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param_0 = oh.make_tensor_value_info("add_param_0", TensorProto.FLOAT, [2]) + mul_param_0 = oh.make_tensor_value_info("mul_param_0", TensorProto.FLOAT, [2]) + add_param_1 = oh.make_tensor_value_info("add_param_1", TensorProto.FLOAT, [2]) + mul_param_1 = oh.make_tensor_value_info("mul_param_1", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param_0, mul_param_0, add_param_1, mul_param_1], + nodes=[ + oh.make_node("Add", ["top_in", "add_param_0"], ["middle_0"]), + oh.make_node("Add", ["middle_0", "add_param_1"], 
["middle_1"]), + oh.make_node("Mul", ["middle_1", "mul_param_0"], ["middle_2"]), + oh.make_node("Mul", ["middle_2", "mul_param_1"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + model.set_initializer("add_param_0", np.asarray([1, 3], dtype=np.float32)) + model.set_initializer("add_param_1", np.asarray([-1, 3], dtype=np.float32)) + model.set_initializer("mul_param_0", np.asarray([2, 4], dtype=np.float32)) + model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32)) + new_model = model.transform_repeated(tx.collapse_repeated_add) + new_model = new_model.transform_repeated(tx.collapse_repeated_mul) + inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} + out_orig = ox.execute_onnx(model, inp_dict)["top_out"] + out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] + assert np.isclose(out_orig, out_transformed).all() diff --git a/tests/test_fold_constants.py b/tests/test_fold_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6bb3815dbbb495b3f8c518505abbdf54759b59 --- /dev/null +++ b/tests/test_fold_constants.py @@ -0,0 +1,26 @@ +from pkgutil import get_data + +import numpy as np +import onnx +import onnx.numpy_helper as np_helper + +import finn.core.onnx_exec as oxe +import finn.transformation.fold_constants as fc +import finn.transformation.infer_shapes as si +from finn.core.modelwrapper import ModelWrapper + + +def test_const_folding(): + raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") + model = ModelWrapper(raw_m) + model = model.transform_single(si.infer_shapes) + model = model.transform_single(fc.fold_constants) + raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") + raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb") + input_tensor = onnx.load_tensor_from_string(raw_i) + output_tensor = onnx.load_tensor_from_string(raw_o) + input_dict = {"Input3": np_helper.to_array(input_tensor)} + output_dict = oxe.execute_onnx(model, input_dict) + assert np.isclose( + np_helper.to_array(output_tensor), output_dict["Plus214_Output_0"], atol=1e-3 + ).all() diff --git a/tests/test_general_transformation.py b/tests/test_general_transformation.py index 037083884071d5146f1ab6e71bce83e9bb794fdc..090c72371a6a14a70c69e2f5f3889dcc1e3b5617 100644 --- a/tests/test_general_transformation.py +++ b/tests/test_general_transformation.py @@ -1,31 +1,13 @@ -import hashlib -import os -import shutil - -import wget +from pkgutil import get_data import finn.transformation.general as tg from finn.core.modelwrapper import ModelWrapper -mnist_onnx_url_base = "https://onnxzoo.blob.core.windows.net/models/opset_8/mnist" -mnist_onnx_filename = "mnist.tar.gz" -mnist_onnx_local_dir = "/tmp/mnist_onnx" - def test_give_unique_node_names(): - try: - os.remove("/tmp/" + mnist_onnx_filename) - except OSError: - pass - dl_ret = wget.download(mnist_onnx_url_base + "/" + mnist_onnx_filename, out="/tmp") - shutil.unpack_archive(dl_ret, mnist_onnx_local_dir) - with open(mnist_onnx_local_dir + "/mnist/model.onnx", "rb") as f: - assert hashlib.md5(f.read()).hexdigest() == "d7cd24a0a76cd492f31065301d468c3d" - model = ModelWrapper(mnist_onnx_local_dir + "/mnist/model.onnx") + raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") + model = ModelWrapper(raw_m) model = model.transform_single(tg.give_unique_node_names) assert model.graph.node[0].name == "Reshape_0" assert model.graph.node[1].name == "Conv_1" assert model.graph.node[11].name == 
"Add_11" - # remove the downloaded model and extracted files - os.remove(dl_ret) - shutil.rmtree(mnist_onnx_local_dir) diff --git a/tests/test_is_linear.py b/tests/test_is_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..1995604b7b0d1b816dea17500486c9ece1ac04c6 --- /dev/null +++ b/tests/test_is_linear.py @@ -0,0 +1,57 @@ +import onnx.helper as oh +from onnx import TensorProto + +import finn.analysis.topology as ta +import finn.transformation.infer_shapes as si +from finn.core.modelwrapper import ModelWrapper + + +def test_is_linear_linear(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param = oh.make_tensor_value_info("add_param", TensorProto.FLOAT, [2]) + mul_param = oh.make_tensor_value_info("mul_param", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param, mul_param], + nodes=[ + oh.make_node("Add", ["top_in", "add_param"], ["middle"]), + oh.make_node("Mul", ["middle", "mul_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + ret = model.analysis(ta.is_linear) + assert ret["is_linear"] is True + + +def test_is_linear_forked_node_output(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param = oh.make_tensor_value_info("add_param", TensorProto.FLOAT, [2]) + mul0_param = oh.make_tensor_value_info("mul0_param", TensorProto.FLOAT, [2]) + mul1_param = oh.make_tensor_value_info("mul1_param", TensorProto.FLOAT, [2]) + mul0_res = oh.make_tensor_value_info("mul0_res", TensorProto.FLOAT, [2]) + mul1_res = oh.make_tensor_value_info("mul1_res", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param, mul0_param, mul1_param, mul0_res, mul1_res], + nodes=[ + oh.make_node("Add", ["top_in", "add_param"], ["middle"]), + oh.make_node("Mul", ["middle", "mul0_param"], ["mul0_res"]), + oh.make_node("Mul", ["middle", "mul1_param"], ["mul1_res"]), + oh.make_node("Add", ["mul0_res", "mul1_res"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + ret = model.analysis(ta.is_linear) + assert ret["is_linear"] is False diff --git a/tests/test_move_add_past_mul.py b/tests/test_move_add_past_mul.py new file mode 100644 index 0000000000000000000000000000000000000000..b19e1ce326b0d14da86e4324c62db1df688eb886 --- /dev/null +++ b/tests/test_move_add_past_mul.py @@ -0,0 +1,70 @@ +import numpy as np +import onnx.helper as oh +from onnx import TensorProto + +import finn.core.onnx_exec as ox +import finn.transformation.infer_shapes as si +import finn.transformation.streamline as tx +from finn.core.modelwrapper import ModelWrapper + + +def test_move_add_past_mul_single(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param = oh.make_tensor_value_info("add_param", TensorProto.FLOAT, [2]) + mul_param = oh.make_tensor_value_info("mul_param", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param, mul_param], + nodes=[ + oh.make_node("Add", ["top_in", "add_param"], ["middle"]), + 
oh.make_node("Mul", ["middle", "mul_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + model.set_initializer("add_param", np.asarray([1, 3], dtype=np.float32)) + model.set_initializer("mul_param", np.asarray([2, 4], dtype=np.float32)) + new_model = model.transform_repeated(tx.move_add_past_mul) + inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} + out_orig = ox.execute_onnx(model, inp_dict)["top_out"] + out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] + assert np.isclose(out_orig, out_transformed).all() + + +def test_move_add_past_mul_multi(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param_0 = oh.make_tensor_value_info("add_param_0", TensorProto.FLOAT, [2]) + mul_param_0 = oh.make_tensor_value_info("mul_param_0", TensorProto.FLOAT, [2]) + add_param_1 = oh.make_tensor_value_info("add_param_1", TensorProto.FLOAT, [2]) + mul_param_1 = oh.make_tensor_value_info("mul_param_1", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param_0, mul_param_0, add_param_1, mul_param_1], + nodes=[ + oh.make_node("Add", ["top_in", "add_param_0"], ["middle_0"]), + oh.make_node("Mul", ["middle_0", "mul_param_0"], ["middle_1"]), + oh.make_node("Add", ["middle_1", "add_param_1"], ["middle_2"]), + oh.make_node("Mul", ["middle_2", "mul_param_1"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + model.set_initializer("add_param_0", np.asarray([1, 3], dtype=np.float32)) + model.set_initializer("mul_param_0", np.asarray([2, 4], dtype=np.float32)) + model.set_initializer("add_param_1", np.asarray([-1, 3], dtype=np.float32)) + model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32)) + new_model = model.transform_repeated(tx.move_add_past_mul) + inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)} + out_orig = ox.execute_onnx(model, inp_dict)["top_out"] + out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] + assert np.isclose(out_orig, out_transformed).all() diff --git a/tests/test_move_scalar_past_matmul.py b/tests/test_move_scalar_past_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cd35d425a1fbf3109b36a7fc6b24bd23706a47 --- /dev/null +++ b/tests/test_move_scalar_past_matmul.py @@ -0,0 +1,74 @@ +import numpy as np +import onnx.helper as oh +from onnx import TensorProto + +import finn.core.onnx_exec as ox +import finn.transformation.infer_shapes as si +import finn.transformation.streamline as tx +from finn.core.modelwrapper import ModelWrapper + + +def test_move_scalar_mul_past_matmul(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2]) + mul_param = oh.make_tensor_value_info("mul_param", TensorProto.FLOAT, [1, 1]) + matmul_param = oh.make_tensor_value_info("matmul_param", TensorProto.FLOAT, [2, 2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[mul_param, matmul_param], + nodes=[ + oh.make_node("Mul", ["top_in", "mul_param"], ["middle"]), + oh.make_node("MatMul", ["middle", "matmul_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + 
model.set_initializer("mul_param", np.asarray([[3]], dtype=np.float32)) + model.set_initializer( + "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32) + ) + new_model = model.transform_repeated(tx.move_scalar_mul_past_matmul) + inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)} + out_orig = ox.execute_onnx(model, inp_dict)["top_out"] + out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] + assert np.isclose(out_orig, out_transformed).all() + assert new_model.graph.node[0].op_type == "MatMul" + assert new_model.graph.node[1].op_type == "Mul" + assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] + + +def test_move_scalar_add_past_matmul(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [1, 2]) + add_param = oh.make_tensor_value_info("add_param", TensorProto.FLOAT, [1, 1]) + matmul_param = oh.make_tensor_value_info("matmul_param", TensorProto.FLOAT, [2, 2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [1, 2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param, matmul_param], + nodes=[ + oh.make_node("Add", ["top_in", "add_param"], ["middle"]), + oh.make_node("MatMul", ["middle", "matmul_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + model.set_initializer("add_param", np.asarray([[3]], dtype=np.float32)) + model.set_initializer( + "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32) + ) + new_model = model.transform_repeated(tx.move_scalar_add_past_matmul) + inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)} + out_orig = ox.execute_onnx(model, inp_dict)["top_out"] + out_transformed = ox.execute_onnx(new_model, inp_dict)["top_out"] + assert np.isclose(out_orig, out_transformed).all() + assert new_model.graph.node[0].op_type == "MatMul" + assert new_model.graph.node[1].op_type == "Add" + assert new_model.graph.node[0].output[0] == new_model.graph.node[1].input[0] diff --git a/tests/test_renaming.py b/tests/test_renaming.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7b5d8050fc034b42a99a01f3f92faad77802ee --- /dev/null +++ b/tests/test_renaming.py @@ -0,0 +1,20 @@ +from pkgutil import get_data + +import finn.transformation.general as tg +import finn.transformation.infer_shapes as si +from finn.core.modelwrapper import ModelWrapper + + +def test_renaming(): + # load the onnx model + raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") + model = ModelWrapper(raw_m) + model = model.transform_single(si.infer_shapes) + model = model.transform_single(tg.give_unique_node_names) + model = model.transform_single(tg.give_readable_tensor_names) + assert model.graph.node[1].op_type == "Conv" + assert model.graph.node[1].name == "Conv_1" + assert model.graph.node[1].input[1] == "Conv_1_param0" + assert model.graph.node[6].op_type == "Add" + assert model.graph.node[6].name == "Add_6" + assert model.graph.node[6].input[1] == "Add_6_param0" diff --git a/tests/test_topology_checks.py b/tests/test_topology_checks.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9582a8dc23fe8ce5b3e409f19e8db7a332e9f7 --- /dev/null +++ b/tests/test_topology_checks.py @@ -0,0 +1,62 @@ +from pkgutil import get_data + +import onnx.helper as oh +from onnx import TensorProto + +import finn.analysis.topology as ta +import finn.transformation.infer_shapes as si +from finn.core.modelwrapper import 
ModelWrapper + + +def test_all_tensors_f32(): + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param = oh.make_tensor_value_info("add_param", TensorProto.FLOAT, [2]) + mul_param = oh.make_tensor_value_info("mul_param", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param, mul_param], + nodes=[ + oh.make_node("Add", ["top_in", "add_param"], ["middle"]), + oh.make_node("Mul", ["middle", "mul_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + ret = model.analysis(ta.all_tensors_f32) + assert ret["all_tensors_f32"] is True + + top_in = oh.make_tensor_value_info("top_in", TensorProto.FLOAT, [2]) + add_param = oh.make_tensor_value_info("add_param", TensorProto.INT8, [2]) + mul_param = oh.make_tensor_value_info("mul_param", TensorProto.FLOAT, [2]) + top_out = oh.make_tensor_value_info("top_out", TensorProto.FLOAT, [2]) + modelproto = oh.make_model( + oh.make_graph( + name="test", + inputs=[top_in], + outputs=[top_out], + value_info=[add_param, mul_param], + nodes=[ + oh.make_node("Add", ["top_in", "add_param"], ["middle"]), + oh.make_node("Mul", ["middle", "mul_param"], ["top_out"]), + ], + ) + ) + model = ModelWrapper(modelproto) + model = model.transform_single(si.infer_shapes) + ret = model.analysis(ta.all_tensors_f32) + assert ret["all_tensors_f32"] is False + + +def test_node_inputs_in_expected_order(): + raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx") + model = ModelWrapper(raw_m) + model = model.transform_single(si.infer_shapes) + ret = model.analysis(ta.node_inputs_in_expected_order) + # this model has an (unnecessary) dynamic reshape for its weight tensor + # and so it fails the check + assert ret["node_inputs_in_expected_order"] is False