From 742d222aa358100fbf5c141f39aa26c002f16020 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sun, 10 Nov 2019 00:01:01 +0000
Subject: [PATCH] [Transform] refactor transformation interface and fix tests

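Transformations are now classes that inherit from a new Transformation
abstract base class and implement an apply(model) method returning a
(transformed_model, model_was_changed) tuple. The transform_single and
transform_repeated methods on ModelWrapper are merged into a single
.transform() method, which re-applies the given transformation until it
reports no further changes.

A sketch of the new calling convention, using transformations defined in
this patch (ModelWrapper accepts an ONNX file path, as in the tests):

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.fold_constants import FoldConstants
    from finn.transformation.infer_shapes import InferShapes

    model = ModelWrapper("model.onnx")
    # .transform() deep-copies the model by default and loops until the
    # transformation returns model_was_changed=False
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())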
---
 src/finn/core/modelwrapper.py                 |  21 +-
 src/finn/transformation/__init__.py           |  22 +-
 .../transformation/batchnorm_to_affine.py     | 131 ++--
 src/finn/transformation/fold_constants.py     |  55 +-
 src/finn/transformation/general.py            | 103 +--
 src/finn/transformation/infer_datatypes.py    |  17 +-
 src/finn/transformation/infer_shapes.py       |  19 +-
 src/finn/transformation/streamline.py         | 682 +++++++++---------
 tests/test_basic_onnx_exec.py                 |   4 +-
 tests/test_batchnorm_to_affine.py             |  12 +-
 tests/test_brevitas_export.py                 |  12 +-
 tests/test_collapse_repeated_op.py            |  10 +-
 tests/test_factor_out_mul_sign_magnitude.py   |   8 +-
 tests/test_fold_constants.py                  |  12 +-
 tests/test_general_transformation.py          |   4 +-
 tests/test_infer_datatypes.py                 |  18 +-
 tests/test_infer_shapes.py                    |   4 +-
 tests/test_is_linear.py                       |   6 +-
 tests/test_mixed_onnx_exec.py                 |   4 +-
 tests/test_move_add_past_mul.py               |  12 +-
 tests/test_move_scalar_past_matmul.py         |  15 +-
 tests/test_renaming.py                        |  14 +-
 tests/test_round_thresholds.py                |   4 +-
 tests/test_sign_to_thres.py                   |  12 +-
 tests/test_streamline.py                      |  56 +-
 tests/test_topology_checks.py                 |   8 +-
 26 files changed, 664 insertions(+), 601 deletions(-)

diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 8a2aa66e7..c734413b7 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -54,28 +54,19 @@ class ModelWrapper:
         """Run given anaylsis_fxn on this model and return resulting dict."""
         return analysis_fxn(self)
 
-    def transform_repeated(self, transform, make_deepcopy=True):
-        """Applies given transform repeatedly until no more changes can be made
+    def transform(self, transformation, make_deepcopy=True):
+        """Applies given Transformation repeatedly until no more changes can be made
         and returns a transformed ModelWrapper instance.
         If make_deepcopy is specified, operates on a new (deep)copy of model.
-        Transform must return (transformed_model, model_was_changed)."""
+        """
         transformed_model = self
         if make_deepcopy:
             transformed_model = copy.deepcopy(self)
         model_was_changed = True
         while model_was_changed:
-            (transformed_model, model_was_changed) = transform(transformed_model)
-        return transformed_model
-
-    def transform_single(self, transform, make_deepcopy=True):
-        """Applies given transform once and returns transformed ModelWrapper
-        instance. If make_deepcopy is specified, operates on a new (deep)copy of
-        model. Transform must return (transformed_model, model_was_changed),
-        although model_was_changed is ignored (see also apply_repeated)."""
-        transformed_model = self
-        if make_deepcopy:
-            transformed_model = copy.deepcopy(self)
-        (transformed_model, model_was_changed) = transform(transformed_model)
+            (transformed_model, model_was_changed) = transformation.apply(
+                transformed_model
+            )
         return transformed_model
 
     def check_compatibility(self):
diff --git a/src/finn/transformation/__init__.py b/src/finn/transformation/__init__.py
index 727b8d32a..3ddce04c1 100644
--- a/src/finn/transformation/__init__.py
+++ b/src/finn/transformation/__init__.py
@@ -2,9 +2,10 @@
 Guide to writing FINN transformations
 -------------------------------------
 
-* Your transformation should take in a ModelWrapper, and return a tuple with
- (transformed_model: ModelWrapper, model_was_changed: Bool)
-* The transformations are meant to be applied using the .transform functions
+* Your transformation must inherit the Transformation abstract base class
+  (see the minimal sketch at the end of this guide).
+* Your transformation's apply function should take in a ModelWrapper, and return
+  a tuple with (transformed_model: ModelWrapper, model_was_changed: Bool)
+* The transformations are meant to be applied using the .transform function
   in ModelWrapper. This makes a deep copy of the input model by default, so
   you don't have to.
 * model_was_changed indicates whether your transformation made any changes to
@@ -14,6 +15,17 @@ Guide to writing FINN transformations
 * You MUST return model_was_changed=False at some point when your transformation
   is called multiple times, otherwise .transform() will loop infinitely.
 * If you cannot guarantee that the transformation will reach a fixed point,
-  you must declare this and notify the user to use .transform_single() instead
-  of .transform_repeated()
+  you must declare this, return model_was_changed=False, and let the user
+  manually re-apply the transformation.
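+
+As a minimal sketch of the expected structure, a (hypothetical) no-op
+transformation that immediately reports a fixed point would look like:
+
+    from finn.transformation import Transformation
+
+    class Identity(Transformation):
+        def apply(self, model):
+            # nothing was changed, so model_was_changed is False
+            return (model, False)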
 """
+
+from abc import ABC, abstractmethod
+
+
+class Transformation(ABC):
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def apply(self, model):
+        pass
diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
index 446e27ba5..655ddd984 100644
--- a/src/finn/transformation/batchnorm_to_affine.py
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -2,70 +2,73 @@ import numpy as np
 from onnx import TensorProto
 from onnx import helper as oh
 
-import finn.transformation.infer_shapes as si
+from finn.transformation import Transformation
+from finn.transformation.infer_shapes import InferShapes
 
 
-def batchnorm_to_affine(model):
+class BatchNormToAffine(Transformation):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "BatchNormalization":
-            graph_modified = True
-            bn_input = n.input[0]
-            bn_output = n.output[0]
-            # extract batchnorm parameters as numpy arrays
-            scale = model.get_initializer(n.input[1])
-            bias = model.get_initializer(n.input[2])
-            mean = model.get_initializer(n.input[3])
-            variance = model.get_initializer(n.input[4])
-            epsilon = 1e-5
-            # find A and B to compute batchnorm as affine transpose Ax+B
-            # TODO is a division by moving avg factor needed for variance?
-            A = scale / np.sqrt(epsilon + variance)
-            B = bias - (A * mean)
-            # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
-            producer = model.find_producer(bn_input)
-            if producer is not None:
-                if producer.op_type == "Unsqueeze":
-                    bn_input = producer.input[0]
-            consumer = model.find_consumer(bn_output)
-            if consumer is not None:
-                if consumer.op_type == "Squeeze":
-                    bn_output = consumer.output[0]
-            data_shape = model.get_tensor_shape(bn_input)
-            # create value_info and initializers for Mul and Add constants
-            mul_const = oh.make_tensor_value_info(
-                model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
-            )
-            graph.value_info.append(mul_const)
-            model.set_initializer(mul_const.name, A)
-            mul_output = oh.make_tensor_value_info(
-                model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape
-            )
-            graph.value_info.append(mul_output)
-            add_const = oh.make_tensor_value_info(
-                model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape
-            )
-            graph.value_info.append(add_const)
-            model.set_initializer(add_const.name, B)
-            # create Mul and Add nodes to replace the batchnorm
-            mul_node = oh.make_node(
-                "Mul", [bn_input, mul_const.name], [mul_output.name]
-            )
-            add_node = oh.make_node(
-                "Add", [mul_output.name, add_const.name], [bn_output]
-            )
-            # insert where the batchnorm is to preserve topological ordering
-            graph.node.insert(node_ind, mul_node)
-            graph.node.insert(node_ind + 1, add_node)
-            # remove old nodes
-            graph.node.remove(n)
-            if consumer is not None:
-                graph.node.remove(consumer)
-            if producer is not None:
-                graph.node.remove(producer)
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "BatchNormalization":
+                graph_modified = True
+                bn_input = n.input[0]
+                bn_output = n.output[0]
+                # extract batchnorm parameters as numpy arrays
+                scale = model.get_initializer(n.input[1])
+                bias = model.get_initializer(n.input[2])
+                mean = model.get_initializer(n.input[3])
+                variance = model.get_initializer(n.input[4])
+                epsilon = 1e-5
+                # find A and B to compute batchnorm as affine transform Ax+B
+                # TODO is a division by moving avg factor needed for variance?
+                A = scale / np.sqrt(epsilon + variance)
+                B = bias - (A * mean)
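+                # derivation: bn(x) = scale*(x - mean)/sqrt(var + eps) + bias
+                #                   = A*x + (bias - A*mean) = A*x + B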
+                # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
+                producer = model.find_producer(bn_input)
+                if producer is not None:
+                    if producer.op_type == "Unsqueeze":
+                        bn_input = producer.input[0]
+                consumer = model.find_consumer(bn_output)
+                if consumer is not None:
+                    if consumer.op_type == "Squeeze":
+                        bn_output = consumer.output[0]
+                data_shape = model.get_tensor_shape(bn_input)
+                # create value_info and initializers for Mul and Add constants
+                mul_const = oh.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
+                )
+                graph.value_info.append(mul_const)
+                model.set_initializer(mul_const.name, A)
+                mul_output = oh.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape
+                )
+                graph.value_info.append(mul_output)
+                add_const = oh.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape
+                )
+                graph.value_info.append(add_const)
+                model.set_initializer(add_const.name, B)
+                # create Mul and Add nodes to replace the batchnorm
+                mul_node = oh.make_node(
+                    "Mul", [bn_input, mul_const.name], [mul_output.name]
+                )
+                add_node = oh.make_node(
+                    "Add", [mul_output.name, add_const.name], [bn_output]
+                )
+                # insert where the batchnorm is to preserve topological ordering
+                graph.node.insert(node_ind, mul_node)
+                graph.node.insert(node_ind + 1, add_node)
+                # remove old nodes
+                graph.node.remove(n)
+                if consumer is not None:
+                    graph.node.remove(consumer)
+                if producer is not None:
+                    graph.node.remove(producer)
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
diff --git a/src/finn/transformation/fold_constants.py b/src/finn/transformation/fold_constants.py
index a9331d306..5b27d906c 100644
--- a/src/finn/transformation/fold_constants.py
+++ b/src/finn/transformation/fold_constants.py
@@ -1,31 +1,34 @@
 import finn.core.onnx_exec as oxe
-import finn.transformation.infer_shapes as si
+from finn.transformation import Transformation
+from finn.transformation.infer_shapes import InferShapes
 
 
-def fold_constants(model):
+class FoldConstants(Transformation):
     """Replace the output of a node with const-only inputs with a precomputed
     result."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    execution_context = model.make_empty_exec_context()
-    for n in graph.node:
-        node_ind += 1
-        node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
-        node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
-        node_out = n.output[0]
-        is_all_constant_inputs = len(node_inp_dyn) == 0
-        ishape = model.get_tensor_shape(n.input[0])
-        is_const_shape = (n.op_type == "Shape") and (ishape is not None)
-        if is_all_constant_inputs or is_const_shape:
-            # this node has no dynamic inputs, only constant ones -- so we can
-            # do constant folding.
-            oxe.execute_node(n, execution_context, graph)
-            # use the execution result as an initializer
-            model.set_initializer(node_out, execution_context[node_out])
-            # remove old node
-            graph.node.remove(n)
-            graph_modified = True
-    if graph_modified:
-        model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        execution_context = model.make_empty_exec_context()
+        for n in graph.node:
+            node_ind += 1
+            node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
+            node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
+            node_out = n.output[0]
+            is_all_constant_inputs = len(node_inp_dyn) == 0
+            ishape = model.get_tensor_shape(n.input[0])
+            is_const_shape = (n.op_type == "Shape") and (ishape is not None)
+            if is_all_constant_inputs or is_const_shape:
+                # this node has no dynamic inputs, only constant ones -- so we can
+                # do constant folding.
+                oxe.execute_node(n, execution_context, graph)
+                # use the execution result as an initializer
+                model.set_initializer(node_out, execution_context[node_out])
+                # remove old node
+                graph.node.remove(n)
+                graph_modified = True
+        if graph_modified:
+            model = model.transform(InferShapes())
+        return (model, graph_modified)
diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 44a4614c6..b6845312b 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -1,59 +1,68 @@
 import finn.core.utils as util
+from finn.transformation import Transformation
 
 
-def give_unique_node_names(model):
+class GiveUniqueNodeNames(Transformation):
     """Give unique names to each node in the graph using enumeration."""
-    optype_count = {}
-    for n in model.graph.node:
-        if n.op_type not in optype_count.keys():
-            optype_count[n.op_type] = 0
-        n.name = "%s_%d" % (n.op_type, optype_count[n.op_type])
-        optype_count[n.op_type] += 1
-    # return model_was_changed = False as single iteration is always enough
-    return (model, False)
 
+    def apply(self, model):
+        optype_count = {}
+        for n in model.graph.node:
+            if n.op_type not in optype_count.keys():
+                optype_count[n.op_type] = 0
+            n.name = "%s_%d" % (n.op_type, optype_count[n.op_type])
+            optype_count[n.op_type] += 1
+        # return model_was_changed = False as a single iteration is always enough
+        return (model, False)
 
-def give_random_tensor_names(model):
+
+class GiveRandomTensorNames(Transformation):
     """Give random tensor names to all tensors."""
-    names = model.get_all_tensor_names()
-    for name in names:
-        model.rename_tensor(name, util.random_string())
-    # return model_was_changed = False as single iteration is always enough
-    return (model, False)
+
+    def apply(self, model):
+        names = model.get_all_tensor_names()
+        for name in names:
+            model.rename_tensor(name, util.random_string())
+        # return model_was_changed = False as a single iteration is always enough
+        return (model, False)
 
 
-def give_readable_tensor_names(model):
+class GiveReadableTensorNames(Transformation):
     """Give more human-readable names to all internal tensors. It's recommended
     to apply GiveUniqueNodeNames prior to this transform."""
-    # to ensure we can use rename_tensor safely (without renaming existing
-    # tensors) we start by giving random names to all tensors
-    model = model.transform_single(give_random_tensor_names)
-    graph = model.graph
-    for n in graph.node:
-        out_num = 0
-        for o in n.output:
-            model.rename_tensor(o, "%s_out%d" % (n.name, out_num))
-            out_num += 1
-        init_in_num = 0
-        for i in n.input:
-            if model.get_initializer(i) is not None:
-                model.rename_tensor(i, "%s_param%d" % (n.name, init_in_num))
-                init_in_num += 1
-    # give special names to the main model input and output
-    model.rename_tensor(model.graph.input[0].name, "global_in")
-    model.rename_tensor(model.graph.output[0].name, "global_out")
-    # return model_was_changed = False as single iteration is always enough
-    return (model, False)
-
-
-def convert_sub_to_add(model):
+
+    def apply(self, model):
+        # to ensure we can use rename_tensor safely (without renaming existing
+        # tensors) we start by giving random names to all tensors
+        model = model.transform(GiveRandomTensorNames())
+        graph = model.graph
+        for n in graph.node:
+            out_num = 0
+            for o in n.output:
+                model.rename_tensor(o, "%s_out%d" % (n.name, out_num))
+                out_num += 1
+            init_in_num = 0
+            for i in n.input:
+                if model.get_initializer(i) is not None:
+                    model.rename_tensor(i, "%s_param%d" % (n.name, init_in_num))
+                    init_in_num += 1
+        # give special names to the main model input and output
+        model.rename_tensor(model.graph.input[0].name, "global_in")
+        model.rename_tensor(model.graph.output[0].name, "global_out")
+        # return model_was_changed = False as a single iteration is always enough
+        return (model, False)
+
+
+class ConvertSubToAdd(Transformation):
     """Convert sub nodes to add nodes of appropriate sign."""
-    graph = model.graph
-    for n in graph.node:
-        if n.op_type == "Sub":
-            A = model.get_initializer(n.input[1])
-            if A is not None:
-                n.op_type = "Add"
-                model.set_initializer(n.input[1], -A)
-    # return model_was_changed = False as single iteration is always enough
-    return (model, False)
+
+    def apply(self, model):
+        graph = model.graph
+        for n in graph.node:
+            if n.op_type == "Sub":
+                A = model.get_initializer(n.input[1])
+                if A is not None:
+                    n.op_type = "Add"
+                    model.set_initializer(n.input[1], -A)
+        # return model_was_changed = False as a single iteration is always enough
+        return (model, False)
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 19d947b57..a311012fc 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -1,4 +1,5 @@
 from finn.core.datatype import DataType
+from finn.transformation import Transformation
 
 
 def _infer_node_datatype(model, node):
@@ -37,11 +38,13 @@ def _infer_node_datatype(model, node):
     return graph_modified
 
 
-def infer_datatypes(model):
+class InferDataTypes(Transformation):
     """Infer FINN DataType info for all intermediate/output tensors based on
-  inputs and node type."""
-    graph = model.graph
-    graph_modified = False
-    for node in graph.node:
-        graph_modified |= _infer_node_datatype(model, node)
-    return (model, graph_modified)
+    inputs and node type."""
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        for node in graph.node:
+            graph_modified |= _infer_node_datatype(model, node)
+        return (model, graph_modified)
diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py
index 44ec42049..3e79587d4 100644
--- a/src/finn/transformation/infer_shapes.py
+++ b/src/finn/transformation/infer_shapes.py
@@ -2,6 +2,7 @@ import onnx.helper as helper
 import onnx.shape_inference as si
 
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation import Transformation
 
 
 def _make_shape_compatible_op(node):
@@ -44,12 +45,14 @@ def _restore_finn_ops(model, hidden_ops):
             pass
 
 
-def infer_shapes(model):
+class InferShapes(Transformation):
     """Ensure every tensor in the model has a specified shape (ValueInfo)."""
-    # hide your riches!
-    hidden_ops = _hide_finn_ops(model)
-    # call regular ONNX shape inference
-    model = ModelWrapper(si.infer_shapes(model.model))
-    # bring back hidden ops
-    _restore_finn_ops(model, hidden_ops)
-    return (model, False)
+
+    def apply(self, model):
+        # hide your riches!
+        hidden_ops = _hide_finn_ops(model)
+        # call regular ONNX shape inference
+        model = ModelWrapper(si.infer_shapes(model.model))
+        # bring back hidden ops
+        _restore_finn_ops(model, hidden_ops)
+        return (model, False)
diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 5053b23c3..fb9e530d6 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -1,384 +1,416 @@
 import numpy as np
 from onnx import helper as oh
 
-import finn.transformation.infer_shapes as si
 from finn.core.datatype import DataType
+from finn.transformation import Transformation
+from finn.transformation.infer_shapes import InferShapes
 
 
-def convert_sign_to_thres(model):
+class ConvertSignToThres(Transformation):
     """Convert Sign node instances to MultiThreshold with threshold at 0."""
-    graph = model.graph
-    graph_modified = False
-    node_ind = 0
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Sign":
-            sign_out_name = n.output[0]
-            # find consumer
-            consumer = model.find_consumer(sign_out_name)
-            assert consumer is not None
-            # change op type and create threshold
-            n.op_type = "MultiThreshold"
-            thres_param_name = model.make_new_valueinfo_name()
-            thres_param = np.asarray([[0]], dtype=np.float32)
-            n.input.append(thres_param_name)
-            n.domain = "finn"
-            model.set_initializer(thres_param_name, thres_param)
-            # convert 0,1 -> -1,+1 with 2*x-1
-            out_shape = model.get_tensor_shape(sign_out_name)
-            # make a mul node
-            # note how set_initializer or set_tensor_shape is called before
-            # calling make_new_valueinfo_name again
-            mul_param_name = model.make_new_valueinfo_name()
-            model.set_initializer(mul_param_name, np.asarray([[2]], dtype=np.float32))
-            mul_out_name = model.make_new_valueinfo_name()
-            model.set_tensor_shape(mul_out_name, out_shape)
-            mul_node = oh.make_node(
-                "Mul", [sign_out_name, mul_param_name], [mul_out_name]
-            )
-            # make an add node
-            add_param_name = model.make_new_valueinfo_name()
-            model.set_initializer(add_param_name, np.asarray([[-1]], dtype=np.float32))
-            add_out_name = model.make_new_valueinfo_name()
-            model.set_tensor_shape(add_out_name, out_shape)
-            add_node = oh.make_node(
-                "Add", [mul_out_name, add_param_name], [add_out_name]
-            )
-            # add new nodes to graph at correct position
-            graph.node.insert(node_ind, mul_node)
-            graph.node.insert(node_ind + 1, add_node)
-            # rewrite consumer's input
-            consumer.input[0] = add_out_name
-            # add quantization annotations
-            model.set_tensor_datatype(sign_out_name, DataType.BINARY)
-            model.set_tensor_datatype(mul_out_name, DataType.UINT2)
-            model.set_tensor_datatype(add_out_name, DataType.BIPOLAR)
-            graph_modified = True
-    return (model, graph_modified)
-
-
-def collapse_repeated_op(model, op_name, make_collapsed_param_fxn):
-    """Collapse repeated consecutive operations with constant parameters into
-    a single operation. make_collapsed_param_fxn must take two tensors and
-    return a tensor which gives the equivalent result using a single op. """
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == op_name:
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == op_name:
-                op0_param_name = n.input[1]
-                op1_param_name = consumer.input[1]
-                op0_param = model.get_initializer(op0_param_name)
-                op1_param = model.get_initializer(op1_param_name)
-                assert op0_param is not None
-                assert op1_param is not None
-                start_name = n.input[0]
-                end_name = consumer.output[0]
-                # compute the new parameter
-                new_param = make_collapsed_param_fxn(op0_param, op1_param)
-                # make and insert new node
-                new_node_param_name = op0_param_name
-                new_node = oh.make_node(
-                    op_name, [start_name, new_node_param_name], [end_name]
-                )
-                graph.node.insert(node_ind, new_node)
-                # replace parameter value
-                model.set_initializer(new_node_param_name, new_param)
-                # remove old nodes
-                graph.node.remove(n)
-                graph.node.remove(consumer)
-                graph_modified = True
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
-
 
-def collapse_repeated_add(model):
-    return collapse_repeated_op(model, "Add", lambda x, y: y + x)
-
-
-def collapse_repeated_mul(model):
-    return collapse_repeated_op(model, "Mul", lambda x, y: y * x)
-
-
-def move_add_past_mul(model):
-    """Move add operations past multiply operations. The aim is to have them
-    next to each other such that they can be collapsed into a single add."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Add":
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "Mul":
-                # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
-                # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
-                # assume input 0 is from the previous layer, input 1 is the
-                # trained (constant) parameter
-                mul_weight_name = consumer.input[1]
-                add_weight_name = n.input[1]
-                A = model.get_initializer(mul_weight_name)
-                B = model.get_initializer(add_weight_name)
-                assert A is not None
-                assert B is not None
-                start_name = n.input[0]
-                middle_name = n.output[0]
-                end_name = consumer.output[0]
-                # compute new param value for add
-                BA = B * A
-                # make and insert new nodes
-                new_mul = oh.make_node(
-                    "Mul", [start_name, mul_weight_name], [middle_name]
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Sign":
+                sign_out_name = n.output[0]
+                # find consumer
+                consumer = model.find_consumer(sign_out_name)
+                assert consumer is not None
+                # change op type and create threshold
+                n.op_type = "MultiThreshold"
+                thres_param_name = model.make_new_valueinfo_name()
+                thres_param = np.asarray([[0]], dtype=np.float32)
+                n.input.append(thres_param_name)
+                n.domain = "finn"
+                model.set_initializer(thres_param_name, thres_param)
+                # convert 0,1 -> -1,+1 with 2*x-1
+                out_shape = model.get_tensor_shape(sign_out_name)
+                # make a mul node
+                # note: set_initializer or set_tensor_shape must be called
+                # before make_new_valueinfo_name is used again, so the new
+                # name is registered and not handed out twice
+                mul_param_name = model.make_new_valueinfo_name()
+                model.set_initializer(
+                    mul_param_name, np.asarray([[2]], dtype=np.float32)
+                )
+                mul_out_name = model.make_new_valueinfo_name()
+                model.set_tensor_shape(mul_out_name, out_shape)
+                mul_node = oh.make_node(
+                    "Mul", [sign_out_name, mul_param_name], [mul_out_name]
                 )
-                new_add = oh.make_node(
-                    "Add", [middle_name, add_weight_name], [end_name]
+                # make an add node
+                add_param_name = model.make_new_valueinfo_name()
+                model.set_initializer(
+                    add_param_name, np.asarray([[-1]], dtype=np.float32)
                 )
-                graph.node.insert(node_ind, new_mul)
-                graph.node.insert(node_ind + 1, new_add)
-                # replace add value
-                model.set_initializer(add_weight_name, BA)
-                # remove old nodes
-                graph.node.remove(n)
-                graph.node.remove(consumer)
+                add_out_name = model.make_new_valueinfo_name()
+                model.set_tensor_shape(add_out_name, out_shape)
+                add_node = oh.make_node(
+                    "Add", [mul_out_name, add_param_name], [add_out_name]
+                )
+                # add new nodes to graph at correct position
+                graph.node.insert(node_ind, mul_node)
+                graph.node.insert(node_ind + 1, add_node)
+                # rewrite consumer's input
+                consumer.input[0] = add_out_name
+                # add quantization annotations
+                model.set_tensor_datatype(sign_out_name, DataType.BINARY)
+                model.set_tensor_datatype(mul_out_name, DataType.UINT2)
+                model.set_tensor_datatype(add_out_name, DataType.BIPOLAR)
                 graph_modified = True
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+        return (model, graph_modified)
 
 
-def move_scalar_mul_past_matmul(model):
-    """Move scalar mul operations past matmul operations. We want to have muls
-    next to each other such that they can be collapsed into a single mul."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Mul":
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "MatMul":
-                mul_weight_name = n.input[1]
-                matmul_weight_name = consumer.input[1]
-                A = model.get_initializer(mul_weight_name)
-                W = model.get_initializer(matmul_weight_name)
-                assert A is not None
-                assert W is not None
-                start_name = n.input[0]
-                middle_name = n.output[0]
-                end_name = consumer.output[0]
-                mm_out_shape = model.get_tensor_shape(end_name)
-                if all(x == 1 for x in A.shape):
-                    # if the mul is scalar, we can simply swap the order of ops
-                    # make and insert new nodes
-                    new_matmul = oh.make_node(
-                        "MatMul", [start_name, matmul_weight_name], [middle_name]
-                    )
-                    new_mul = oh.make_node(
-                        "Mul", [middle_name, mul_weight_name], [end_name]
+class CollapseRepeatedOp(Transformation):
+    """Collapse repeated consecutive operations with constant parameters into
+    a single operation. make_collapsed_param_fxn must take two tensors and
+    return a tensor which gives the equivalent result using a single op."""
+
+    def __init__(self, op_name, make_collapsed_param_fxn):
+        super().__init__()
+        self.op_name = op_name
+        self.make_collapsed_param_fxn = make_collapsed_param_fxn
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == self.op_name:
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == self.op_name:
+                    op0_param_name = n.input[1]
+                    op1_param_name = consumer.input[1]
+                    op0_param = model.get_initializer(op0_param_name)
+                    op1_param = model.get_initializer(op1_param_name)
+                    assert op0_param is not None
+                    assert op1_param is not None
+                    start_name = n.input[0]
+                    end_name = consumer.output[0]
+                    # compute the new parameter
+                    new_param = self.make_collapsed_param_fxn(op0_param, op1_param)
+                    # make and insert new node
+                    new_node_param_name = op0_param_name
+                    new_node = oh.make_node(
+                        self.op_name, [start_name, new_node_param_name], [end_name]
                     )
-                    graph.node.insert(node_ind, new_matmul)
-                    graph.node.insert(node_ind + 1, new_mul)
-                    model.set_tensor_shape(middle_name, mm_out_shape)
+                    graph.node.insert(node_ind, new_node)
+                    # replace parameter value
+                    model.set_initializer(new_node_param_name, new_param)
                     # remove old nodes
                     graph.node.remove(n)
                     graph.node.remove(consumer)
                     graph_modified = True
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
 
 
-def move_scalar_add_past_matmul(model):
-    """Move scalar add operations past matmul operations. We want to have adds
+class CollapseRepeatedAdd(CollapseRepeatedOp):
+    def __init__(self):
+        super().__init__("Add", lambda x, y: y + x)
+
+
+class CollapseRepeatedMul(CollapseRepeatedOp):
+    def __init__(self):
+        super().__init__("Mul", lambda x, y: y * x)
+
+
+class MoveAddPastMul(Transformation):
+    """Move add operations past multiply operations. The aim is to have them
     next to each other such that they can be collapsed into a single add."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Add":
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "MatMul":
-                add_weight_name = n.input[1]
-                matmul_weight_name = consumer.input[1]
-                A = model.get_initializer(add_weight_name)
-                W = model.get_initializer(matmul_weight_name)
-                assert A is not None
-                assert W is not None
-                start_name = n.input[0]
-                middle_name = n.output[0]
-                end_name = consumer.output[0]
-                mm_out_shape = model.get_tensor_shape(end_name)
-                if all(x == 1 for x in A.shape):
-                    # if the add is scalar, we can move it past the matmul
-                    # by taking it past the matmul with a dot product
-                    Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W)
-                    # update the add weight
-                    model.set_initializer(add_weight_name, Anew)
-                    new_matmul = oh.make_node(
-                        "MatMul", [start_name, matmul_weight_name], [middle_name]
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Mul":
+                    # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
+                    # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
+                    # assume input 0 is from the previous layer, input 1 is the
+                    # trained (constant) parameter
+                    mul_weight_name = consumer.input[1]
+                    add_weight_name = n.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    B = model.get_initializer(add_weight_name)
+                    assert A is not None
+                    assert B is not None
+                    start_name = n.input[0]
+                    middle_name = n.output[0]
+                    end_name = consumer.output[0]
+                    # compute new param value for add
+                    BA = B * A
+                    # make and insert new nodes
+                    new_mul = oh.make_node(
+                        "Mul", [start_name, mul_weight_name], [middle_name]
                     )
                     new_add = oh.make_node(
                         "Add", [middle_name, add_weight_name], [end_name]
                     )
-                    graph.node.insert(node_ind, new_matmul)
+                    graph.node.insert(node_ind, new_mul)
                     graph.node.insert(node_ind + 1, new_add)
-                    model.set_tensor_shape(middle_name, mm_out_shape)
+                    # replace add value
+                    model.set_initializer(add_weight_name, BA)
                     # remove old nodes
                     graph.node.remove(n)
                     graph.node.remove(consumer)
                     graph_modified = True
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
 
 
-def absorb_add_into_multi_threshold(model):
+class MoveScalarMulPastMatMul(Transformation):
+    """Move scalar mul operations past matmul operations. We want to have muls
+    next to each other such that they can be collapsed into a single mul."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Mul":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MatMul":
+                    mul_weight_name = n.input[1]
+                    matmul_weight_name = consumer.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    W = model.get_initializer(matmul_weight_name)
+                    assert A is not None
+                    assert W is not None
+                    start_name = n.input[0]
+                    middle_name = n.output[0]
+                    end_name = consumer.output[0]
+                    mm_out_shape = model.get_tensor_shape(end_name)
+                    if all(x == 1 for x in A.shape):
+                        # if the mul is scalar, we can simply swap the order of ops
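+                        # (valid since a scalar mul commutes: (a*x)W = a*(xW))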
+                        # make and insert new nodes
+                        new_matmul = oh.make_node(
+                            "MatMul", [start_name, matmul_weight_name], [middle_name]
+                        )
+                        new_mul = oh.make_node(
+                            "Mul", [middle_name, mul_weight_name], [end_name]
+                        )
+                        graph.node.insert(node_ind, new_matmul)
+                        graph.node.insert(node_ind + 1, new_mul)
+                        model.set_tensor_shape(middle_name, mm_out_shape)
+                        # remove old nodes
+                        graph.node.remove(n)
+                        graph.node.remove(consumer)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
+class MoveScalarAddPastMatMul(Transformation):
+    """Move scalar add operations past matmul operations. We want to have adds
+    next to each other such that they can be collapsed into a single add."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MatMul":
+                    add_weight_name = n.input[1]
+                    matmul_weight_name = consumer.input[1]
+                    A = model.get_initializer(add_weight_name)
+                    W = model.get_initializer(matmul_weight_name)
+                    assert A is not None
+                    assert W is not None
+                    start_name = n.input[0]
+                    middle_name = n.output[0]
+                    end_name = consumer.output[0]
+                    mm_out_shape = model.get_tensor_shape(end_name)
+                    if all(x == 1 for x in A.shape):
+                        # if the add is scalar, we can move it past the matmul:
+                        # (x + a)W = xW + (a*1)W, so the new add parameter is
+                        # the dot product of the broadcast vector (a*ones) with W
+                        Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W)
+                        # update the add weight
+                        model.set_initializer(add_weight_name, Anew)
+                        new_matmul = oh.make_node(
+                            "MatMul", [start_name, matmul_weight_name], [middle_name]
+                        )
+                        new_add = oh.make_node(
+                            "Add", [middle_name, add_weight_name], [end_name]
+                        )
+                        graph.node.insert(node_ind, new_matmul)
+                        graph.node.insert(node_ind + 1, new_add)
+                        model.set_tensor_shape(middle_name, mm_out_shape)
+                        # remove old nodes
+                        graph.node.remove(n)
+                        graph.node.remove(consumer)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
+class AbsorbAddIntoMultiThreshold(Transformation):
     """Absorb preceding Add ops into MultiThreshold by updating the threshold
     values."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Add":
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "MultiThreshold":
-                add_weight_name = n.input[1]
-                threshold_name = consumer.input[1]
-                A = model.get_initializer(add_weight_name)
-                T = model.get_initializer(threshold_name)
-                assert A is not None
-                assert T is not None
-                start_name = n.input[0]
-                # compute new thresholds and set initializer
-                Tnew = T - A.reshape(-1, T.shape[1])
-                model.set_initializer(threshold_name, Tnew)
-                # wire add input directly to MultiThreshold
-                consumer.input[0] = start_name
-                # remove the add node
-                graph.node.remove(n)
-                graph_modified = True
-    return (model, graph_modified)
 
-
-def absorb_mul_into_multi_threshold(model):
-    """Absorb preceding Mul ops into MultiThreshold by updating the threshold
-    values. Only *positive* scalar/1D vectors can be absorbed."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Mul":
-            mul_weight_name = n.input[1]
-            A = model.get_initializer(mul_weight_name)
-            assert A is not None
-            is_signed = (A < 0).any()
-            is_scalar = np.prod(A.shape) == 1
-            is_1d = len(A.shape) == 2 and A.shape[0] == 1
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "MultiThreshold":
-                if not is_signed and (is_1d or is_scalar):
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MultiThreshold":
+                    add_weight_name = n.input[1]
                     threshold_name = consumer.input[1]
+                    A = model.get_initializer(add_weight_name)
                     T = model.get_initializer(threshold_name)
+                    assert A is not None
                     assert T is not None
                     start_name = n.input[0]
                     # compute new thresholds and set initializer
-                    Tnew = T / A.reshape(-1, T.shape[1])
-                    # TODO: need to handle negative A values correctly; produce
-                    # mul sign mask and merge into preceding matmul?
+                    Tnew = T - A.reshape(-1, T.shape[1])
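+                    # (x + a >= T) iff (x >= T - a), so shift thresholds by -a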
                     model.set_initializer(threshold_name, Tnew)
                     # wire add input directly to MultiThreshold
                     consumer.input[0] = start_name
-                    # remove the mul node
+                    # remove the add node
                     graph.node.remove(n)
                     graph_modified = True
-    return (model, graph_modified)
+        return (model, graph_modified)
+
+
+class AbsorbMulIntoMultiThreshold(Transformation):
+    """Absorb preceding Mul ops into MultiThreshold by updating the threshold
+    values. Only *positive* scalar/1D vectors can be absorbed."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Mul":
+                mul_weight_name = n.input[1]
+                A = model.get_initializer(mul_weight_name)
+                assert A is not None
+                is_signed = (A < 0).any()
+                is_scalar = np.prod(A.shape) == 1
+                is_1d = len(A.shape) == 2 and A.shape[0] == 1
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MultiThreshold":
+                    if not is_signed and (is_1d or is_scalar):
+                        threshold_name = consumer.input[1]
+                        T = model.get_initializer(threshold_name)
+                        assert T is not None
+                        start_name = n.input[0]
+                        # compute new thresholds and set initializer
+                        Tnew = T / A.reshape(-1, T.shape[1])
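+                        # (valid for positive a: x*a >= T iff x >= T/a)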
+                        # TODO: need to handle negative A values correctly; produce
+                        # mul sign mask and merge into preceding matmul?
+                        model.set_initializer(threshold_name, Tnew)
+                        # wire mul input directly to MultiThreshold
+                        consumer.input[0] = start_name
+                        # remove the mul node
+                        graph.node.remove(n)
+                        graph_modified = True
+        return (model, graph_modified)
 
 
-def factor_out_mul_sign_magnitude(model):
+class FactorOutMulSignMagnitude(Transformation):
     """Split multiply-by-constant nodes into two multiply-by-constant nodes,
     where the first node is a bipolar vector (of signs) and the second is a
     vector of magnitudes."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Mul":
-            mul_weight_name = n.input[1]
-            A = model.get_initializer(mul_weight_name)
-            assert A is not None
-            is_scalar = np.prod(A.shape) == 1
-            is_1d = len(A.shape) == 2 and A.shape[0] == 1
-            is_not_bipolar = (
-                model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR
-            )
-            is_signed = (A < 0).any()
-            if is_signed and (is_scalar or is_1d) and is_not_bipolar:
-                start_name = n.input[0]
-                in_shape = model.get_tensor_shape(start_name)
-                middle_name = model.make_new_valueinfo_name()
-                model.set_tensor_shape(middle_name, in_shape)
-                sign_mul_param_name = model.make_new_valueinfo_name()
-                # create new mul node with sign(A) as the operand
-                sgn = np.sign(A)
-                model.set_initializer(sign_mul_param_name, sgn)
-                model.set_tensor_datatype(sign_mul_param_name, DataType.BIPOLAR)
-                # replace original mul weight by magnitudes
-                model.set_initializer(mul_weight_name, np.abs(A))
-                new_mul = oh.make_node(
-                    "Mul", [start_name, sign_mul_param_name], [middle_name]
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Mul":
+                mul_weight_name = n.input[1]
+                A = model.get_initializer(mul_weight_name)
+                assert A is not None
+                is_scalar = np.prod(A.shape) == 1
+                is_1d = len(A.shape) == 2 and A.shape[0] == 1
+                is_not_bipolar = (
+                    model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR
                 )
-                n.input[0] = middle_name
-                graph.node.insert(node_ind - 1, new_mul)
-                graph_modified = True
-    return (model, graph_modified)
+                is_signed = (A < 0).any()
+                if is_signed and (is_scalar or is_1d) and is_not_bipolar:
+                    start_name = n.input[0]
+                    in_shape = model.get_tensor_shape(start_name)
+                    middle_name = model.make_new_valueinfo_name()
+                    model.set_tensor_shape(middle_name, in_shape)
+                    sign_mul_param_name = model.make_new_valueinfo_name()
+                    # create new mul node with sign(A) as the operand
+                    sgn = np.sign(A)
+                    model.set_initializer(sign_mul_param_name, sgn)
+                    model.set_tensor_datatype(sign_mul_param_name, DataType.BIPOLAR)
+                    # replace original mul weight by magnitudes
+                    model.set_initializer(mul_weight_name, np.abs(A))
+                    new_mul = oh.make_node(
+                        "Mul", [start_name, sign_mul_param_name], [middle_name]
+                    )
+                    n.input[0] = middle_name
+                    graph.node.insert(node_ind - 1, new_mul)
+                    graph_modified = True
+        return (model, graph_modified)
 
 
-def absorb_1bit_mul_into_matmul(model):
+class Absorb1BitMulIntoMatMul(Transformation):
     """Absorb bipolar or binary multiplications into the preciding matrix
     multiply."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "MatMul":
-            matmul_weight_name = n.input[1]
-            W = model.get_initializer(matmul_weight_name)
-            assert W is not None
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "Mul":
-                mul_weight_name = consumer.input[1]
-                A = model.get_initializer(mul_weight_name)
-                assert A is not None
-                is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
-                if is_1bit:
-                    Wnew = A * W
-                    assert Wnew.shape == W.shape
-                    model.set_initializer(matmul_weight_name, Wnew)
-                    n.output[0] = consumer.output[0]
-                    graph.node.remove(consumer)
-                    graph_modified = True
-    return (model, graph_modified)
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "MatMul":
+                matmul_weight_name = n.input[1]
+                W = model.get_initializer(matmul_weight_name)
+                assert W is not None
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Mul":
+                    mul_weight_name = consumer.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    assert A is not None
+                    is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
+                    if is_1bit:
+                        Wnew = A * W
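+                        # fold the +/-1 scale into the weights: (xW)*A = x(W*A)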
+                        assert Wnew.shape == W.shape
+                        model.set_initializer(matmul_weight_name, Wnew)
+                        n.output[0] = consumer.output[0]
+                        graph.node.remove(consumer)
+                        graph_modified = True
+        return (model, graph_modified)
 
 
-def round_thresholds(model):
+class RoundThresholds(Transformation):
     """For MultiThreshold nodes operating on integer inputs, round up
     threshold values to the nearest integer."""
-    graph = model.graph
-    graph_modified = False
-    for n in graph.node:
-        if n.op_type == "MultiThreshold":
-            idtype = model.get_tensor_datatype(n.input[0])
-            T = model.get_initializer(n.input[1])
-            Tnew = np.ceil(T)
-            if idtype.is_integer() and (T != Tnew).any():
-                # round up the thresholds to nearest integer
-                model.set_initializer(n.input[1], Tnew)
-                # use same datatype as inputs for thresholds
-                model.set_tensor_datatype(n.input[1], idtype)
-                graph_modified = True
-    return (model, graph_modified)
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        for n in graph.node:
+            if n.op_type == "MultiThreshold":
+                idtype = model.get_tensor_datatype(n.input[0])
+                T = model.get_initializer(n.input[1])
+                Tnew = np.ceil(T)
+                if idtype.is_integer() and (T != Tnew).any():
+                    # round the thresholds up to the nearest integer
+                    model.set_initializer(n.input[1], Tnew)
+                    # use the same datatype as the input for the thresholds
+                    model.set_tensor_datatype(n.input[1], idtype)
+                    graph_modified = True
+        return (model, graph_modified)
diff --git a/tests/test_basic_onnx_exec.py b/tests/test_basic_onnx_exec.py
index 30e9106fe..c7b3da1b7 100644
--- a/tests/test_basic_onnx_exec.py
+++ b/tests/test_basic_onnx_exec.py
@@ -5,15 +5,15 @@ import onnx
 import onnx.numpy_helper as np_helper
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_mnist_onnx_download_extract_run():
     # load the onnx model
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb")
diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index bb66f98f4..dec01b53a 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -8,10 +8,10 @@ import torch
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.batchnorm_to_affine as tx
-import finn.transformation.fold_constants as fc
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.batchnorm_to_affine import BatchNormToAffine
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 transformed_onnx_path = "test_output_lfc_transformed.onnx"
@@ -27,9 +27,9 @@ def test_batchnorm_to_affine():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
-    new_model = model.transform_single(tx.batchnorm_to_affine)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    new_model = model.transform(BatchNormToAffine())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 38ec67175..641e5e3c4 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -9,9 +9,9 @@ import torch
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.fold_constants as fc
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 # TODO get from config instead, hardcoded to Docker path for now
@@ -88,8 +88,8 @@ def test_brevitas_to_onnx_export_and_exec_lfc_w1a1():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
@@ -113,8 +113,8 @@ def test_brevitas_to_onnx_export_and_exec_lfc_w1a2():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
diff --git a/tests/test_collapse_repeated_op.py b/tests/test_collapse_repeated_op.py
index d8cdde3c6..d97cdbc30 100644
--- a/tests/test_collapse_repeated_op.py
+++ b/tests/test_collapse_repeated_op.py
@@ -3,9 +3,9 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.core.onnx_exec as ox
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as tx
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul
 
 
 def test_collapse_repeated_op():
@@ -30,12 +30,12 @@ def test_collapse_repeated_op():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("add_param_0", np.asarray([1, 3], dtype=np.float32))
     model.set_initializer("add_param_1", np.asarray([-1, 3], dtype=np.float32))
     model.set_initializer("mul_param_0", np.asarray([2, 4], dtype=np.float32))
     model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32))
-    new_model = model.transform_repeated(tx.collapse_repeated_add)
-    new_model = new_model.transform_repeated(tx.collapse_repeated_mul)
+    new_model = model.transform(CollapseRepeatedAdd())
+    new_model = new_model.transform(CollapseRepeatedMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
diff --git a/tests/test_factor_out_mul_sign_magnitude.py b/tests/test_factor_out_mul_sign_magnitude.py
index 2b4c0ba93..3786492fe 100644
--- a/tests/test_factor_out_mul_sign_magnitude.py
+++ b/tests/test_factor_out_mul_sign_magnitude.py
@@ -3,9 +3,9 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.core.onnx_exec as ox
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as tx
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import FactorOutMulSignMagnitude
 
 
 def test_factor_out_mul_sign_magnitude():
@@ -22,8 +22,8 @@ def test_factor_out_mul_sign_magnitude():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("mul_param", np.asarray([[-1, 4]], dtype=np.float32))
-    new_model = model.transform_repeated(tx.factor_out_mul_sign_magnitude)
+    new_model = model.transform(FactorOutMulSignMagnitude())
     inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
diff --git a/tests/test_fold_constants.py b/tests/test_fold_constants.py
index 894887b0c..09dbd95c2 100644
--- a/tests/test_fold_constants.py
+++ b/tests/test_fold_constants.py
@@ -8,9 +8,9 @@ import onnx.numpy_helper as np_helper
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.fold_constants as fc
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 
@@ -18,8 +18,8 @@ export_onnx_path = "test_output_lfc.onnx"
 def test_const_folding():
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_single(fc.fold_constants)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
@@ -35,8 +35,8 @@ def test_const_folding_shapes():
     lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
     assert model.graph.node[0].op_type == "Reshape"
     assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28]
     assert list(model.get_tensor_shape("27")) == [1, 784]
diff --git a/tests/test_general_transformation.py b/tests/test_general_transformation.py
index d4376cca6..dd44502ad 100644
--- a/tests/test_general_transformation.py
+++ b/tests/test_general_transformation.py
@@ -1,13 +1,13 @@
 from pkgutil import get_data
 
-import finn.transformation.general as tg
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.general import GiveUniqueNodeNames
 
 
 def test_give_unique_node_names():
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(tg.give_unique_node_names)
+    model = model.transform(GiveUniqueNodeNames())
     assert model.graph.node[0].name == "Reshape_0"
     assert model.graph.node[1].name == "Conv_0"
     assert model.graph.node[11].name == "Add_2"
diff --git a/tests/test_infer_datatypes.py b/tests/test_infer_datatypes.py
index fd1c99a1d..e4269499c 100644
--- a/tests/test_infer_datatypes.py
+++ b/tests/test_infer_datatypes.py
@@ -4,12 +4,12 @@ import brevitas.onnx as bo
 import torch
 from models.LFC import LFC
 
-import finn.transformation.fold_constants as fc
-import finn.transformation.general as tg
-import finn.transformation.infer_datatypes as id
-import finn.transformation.infer_shapes as si
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 # TODO get from config instead, hardcoded to Docker path for now
@@ -24,11 +24,11 @@ def test_infer_datatypes():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
-    model = model.transform_single(tg.give_unique_node_names)
-    model = model.transform_single(tg.give_readable_tensor_names)
-    model = model.transform_repeated(id.infer_datatypes)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+    model = model.transform(InferDataTypes())
     assert model.get_tensor_datatype("MatMul_0_out0") == DataType.INT32
     assert model.get_tensor_datatype("MatMul_1_out0") == DataType.INT32
     assert model.get_tensor_datatype("MatMul_2_out0") == DataType.INT32
diff --git a/tests/test_infer_shapes.py b/tests/test_infer_shapes.py
index 9ca0c9303..20841b322 100644
--- a/tests/test_infer_shapes.py
+++ b/tests/test_infer_shapes.py
@@ -4,8 +4,8 @@ import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.utils as util
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_infer_shapes():
@@ -49,7 +49,7 @@ def test_infer_shapes():
     ), "All tensors are already specified before the shape inference execution"
 
     # perform shape inference on mixed model
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
 
     # second check routine
     # now all shapes should be specified and mt_node output shape is (1,8,28,28)
diff --git a/tests/test_is_linear.py b/tests/test_is_linear.py
index 1995604b7..cce4612a9 100644
--- a/tests/test_is_linear.py
+++ b/tests/test_is_linear.py
@@ -2,8 +2,8 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.analysis.topology as ta
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_is_linear_linear():
@@ -24,7 +24,7 @@ def test_is_linear_linear():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.is_linear)
     assert ret["is_linear"] is True
 
@@ -52,6 +52,6 @@ def test_is_linear_forked_node_output():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.is_linear)
     assert ret["is_linear"] is False
diff --git a/tests/test_mixed_onnx_exec.py b/tests/test_mixed_onnx_exec.py
index 0ac3f09da..75170da3a 100644
--- a/tests/test_mixed_onnx_exec.py
+++ b/tests/test_mixed_onnx_exec.py
@@ -2,8 +2,8 @@ import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_execute_mixed_model():
@@ -30,7 +30,7 @@ def test_execute_mixed_model():
     model_def = helper.make_model(graph_def, producer_name="onnx-example")
 
     model = ModelWrapper(model_def)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
 
     inputs = np.asarray(
         [
diff --git a/tests/test_move_add_past_mul.py b/tests/test_move_add_past_mul.py
index a23827d0e..565bbce39 100644
--- a/tests/test_move_add_past_mul.py
+++ b/tests/test_move_add_past_mul.py
@@ -3,9 +3,9 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.core.onnx_exec as ox
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as tx
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import MoveAddPastMul
 
 
 def test_move_add_past_mul_single():
@@ -26,10 +26,10 @@ def test_move_add_past_mul_single():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("add_param", np.asarray([1, 3], dtype=np.float32))
     model.set_initializer("mul_param", np.asarray([2, 4], dtype=np.float32))
-    new_model = model.transform_repeated(tx.move_add_past_mul)
+    new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
 
@@ -56,11 +56,11 @@ def test_move_add_past_mul_multi():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("add_param_0", np.asarray([1, 3], dtype=np.float32))
     model.set_initializer("mul_param_0", np.asarray([2, 4], dtype=np.float32))
     model.set_initializer("add_param_1", np.asarray([-1, 3], dtype=np.float32))
     model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32))
-    new_model = model.transform_repeated(tx.move_add_past_mul)
+    new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
diff --git a/tests/test_move_scalar_past_matmul.py b/tests/test_move_scalar_past_matmul.py
index 7bbdd7dd8..c2771ce94 100644
--- a/tests/test_move_scalar_past_matmul.py
+++ b/tests/test_move_scalar_past_matmul.py
@@ -3,9 +3,12 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.core.onnx_exec as ox
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as tx
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import (
+    MoveScalarAddPastMatMul,
+    MoveScalarMulPastMatMul
+)
 
 
 def test_move_scalar_mul_past_matmul():
@@ -26,12 +29,12 @@ def test_move_scalar_mul_past_matmul():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("mul_param", np.asarray([[3]], dtype=np.float32))
     model.set_initializer(
         "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32)
     )
-    new_model = model.transform_repeated(tx.move_scalar_mul_past_matmul)
+    new_model = model.transform(MoveScalarMulPastMatMul())
     inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
     assert new_model.graph.node[0].op_type == "MatMul"
@@ -57,12 +60,12 @@ def test_move_scalar_add_past_matmul():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("add_param", np.asarray([[3]], dtype=np.float32))
     model.set_initializer(
         "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32)
     )
-    new_model = model.transform_repeated(tx.move_scalar_add_past_matmul)
+    new_model = model.transform(MoveScalarAddPastMatMul())
     inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
     assert new_model.graph.node[0].op_type == "MatMul"
diff --git a/tests/test_renaming.py b/tests/test_renaming.py
index c13f7ee06..aec1c8a10 100644
--- a/tests/test_renaming.py
+++ b/tests/test_renaming.py
@@ -5,18 +5,18 @@ import onnx
 import onnx.numpy_helper as np_helper
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.general as tg
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_renaming():
     # load the onnx model
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_single(tg.give_unique_node_names)
-    model = model.transform_single(tg.give_readable_tensor_names)
+    model = model.transform(InferShapes())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
     # do some basic checks
     assert model.graph.input[0].name == "global_in"
     assert model.graph.output[0].name == "global_out"
@@ -27,8 +27,8 @@ def test_renaming():
     assert model.graph.node[6].name == "Add_1"
     assert model.graph.node[6].input[1] == "Add_1_param0"
     # ensure running renaming twice still yields the same names
-    model = model.transform_single(tg.give_unique_node_names)
-    model = model.transform_single(tg.give_readable_tensor_names)
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
     assert model.graph.node[1].op_type == "Conv"
     assert model.graph.node[1].name == "Conv_0"
     assert model.graph.node[1].input[1] == "Conv_0_param0"
diff --git a/tests/test_round_thresholds.py b/tests/test_round_thresholds.py
index 28add7313..afd1d86f8 100644
--- a/tests/test_round_thresholds.py
+++ b/tests/test_round_thresholds.py
@@ -2,9 +2,9 @@ import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.streamline as sl
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.streamline import RoundThresholds
 
 
 def test_round_thresholds():
@@ -27,7 +27,7 @@ def test_round_thresholds():
     orig_n = oxe.execute_onnx(model, inp_dict_n)["out"]
     orig_c = oxe.execute_onnx(model, inp_dict_c)["out"]
     assert model.get_tensor_datatype("thresholds") == DataType.FLOAT32
-    new_model = model.transform_repeated(sl.round_thresholds)
+    new_model = model.transform(RoundThresholds())
     # rounded up thresholds should have same dtype as input
     assert new_model.get_tensor_datatype("thresholds") == DataType.INT8
     new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"]
diff --git a/tests/test_sign_to_thres.py b/tests/test_sign_to_thres.py
index 4e18822cf..75327df3e 100644
--- a/tests/test_sign_to_thres.py
+++ b/tests/test_sign_to_thres.py
@@ -8,10 +8,10 @@ import torch
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.fold_constants as fc
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as sl
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import ConvertSignToThres
 
 export_onnx_path = "test_output_lfc.onnx"
 transformed_onnx_path = "test_output_lfc_transformed.onnx"
@@ -27,9 +27,9 @@ def test_sign_to_thres():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
-    new_model = model.transform_single(sl.convert_sign_to_thres)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    new_model = model.transform(ConvertSignToThres())
     assert new_model.graph.node[3].op_type == "MultiThreshold"
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
diff --git a/tests/test_streamline.py b/tests/test_streamline.py
index 1a93c6ff1..75b030107 100644
--- a/tests/test_streamline.py
+++ b/tests/test_streamline.py
@@ -9,13 +9,17 @@ import torch
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.batchnorm_to_affine as ba
-import finn.transformation.fold_constants as fc
-import finn.transformation.general as tg
-import finn.transformation.infer_datatypes as di
-import finn.transformation.infer_shapes as si
 import finn.transformation.streamline as sl
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.batchnorm_to_affine import BatchNormToAffine
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import (
+    ConvertSubToAdd,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames
+)
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 # TODO get from config instead, hardcoded to Docker path for now
@@ -30,10 +34,10 @@ def test_streamline_lfc_w1a1():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
-    model = model.transform_single(tg.give_unique_node_names)
-    model = model.transform_single(tg.give_readable_tensor_names)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
@@ -42,26 +46,26 @@ def test_streamline_lfc_w1a1():
     expected_ctx = oxe.execute_onnx(model, input_dict, True)
     expected = expected_ctx[model.graph.output[0].name]
     transforms = [
-        tg.convert_sub_to_add,
-        ba.batchnorm_to_affine,
-        sl.convert_sign_to_thres,
-        sl.move_scalar_add_past_matmul,
-        sl.move_scalar_mul_past_matmul,
-        sl.move_add_past_mul,
-        sl.collapse_repeated_add,
-        sl.collapse_repeated_mul,
-        sl.absorb_add_into_multi_threshold,
-        sl.factor_out_mul_sign_magnitude,
-        sl.absorb_mul_into_multi_threshold,
-        sl.absorb_1bit_mul_into_matmul,
-        sl.round_thresholds,
+        ConvertSubToAdd(),
+        BatchNormToAffine(),
+        sl.ConvertSignToThres(),
+        sl.MoveScalarAddPastMatMul(),
+        sl.MoveScalarMulPastMatMul(),
+        sl.MoveAddPastMul(),
+        sl.CollapseRepeatedAdd(),
+        sl.CollapseRepeatedMul(),
+        sl.AbsorbAddIntoMultiThreshold(),
+        sl.FactorOutMulSignMagnitude(),
+        sl.AbsorbMulIntoMultiThreshold(),
+        sl.Absorb1BitMulIntoMatMul(),
+        sl.RoundThresholds(),
     ]
     trn_ind = 0
     for trn in transforms:
-        model = model.transform_repeated(trn)
-        model = model.transform_single(tg.give_unique_node_names)
-        model = model.transform_single(tg.give_readable_tensor_names)
-        model = model.transform_repeated(di.infer_datatypes)
+        model = model.transform(trn)
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(GiveReadableTensorNames())
+        model = model.transform(InferDataTypes())
         produced_ctx = oxe.execute_onnx(model, input_dict, True)
         produced = produced_ctx[model.graph.output[0].name]
-        # model.save("%d-%s.onnx" % (trn_ind, trn.__name__))
+        # model.save("%d-%s.onnx" % (trn_ind, type(trn).__name__))
diff --git a/tests/test_topology_checks.py b/tests/test_topology_checks.py
index 2e9582a8d..e28ceac09 100644
--- a/tests/test_topology_checks.py
+++ b/tests/test_topology_checks.py
@@ -4,8 +4,8 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.analysis.topology as ta
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_all_tensors_f32():
@@ -26,7 +26,7 @@ def test_all_tensors_f32():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.all_tensors_f32)
     assert ret["all_tensors_f32"] is True
 
@@ -47,7 +47,7 @@ def test_all_tensors_f32():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.all_tensors_f32)
     assert ret["all_tensors_f32"] is False
 
@@ -55,7 +55,7 @@ def test_all_tensors_f32():
 def test_node_inputs_in_expected_order():
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.node_inputs_in_expected_order)
     # this model has an (unnecessary) dynamic reshape for its weight tensor
     # and so it fails the check
-- 
GitLab