diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 8a2aa66e702593c832ccf6c45dd3c2804cc14db9..c734413b7418a4d88e5e51188348a27526e861ba 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -54,28 +54,19 @@ class ModelWrapper:
         """Run given anaylsis_fxn on this model and return resulting dict."""
         return analysis_fxn(self)
 
-    def transform_repeated(self, transform, make_deepcopy=True):
-        """Applies given transform repeatedly until no more changes can be made
+    def transform(self, transformation, make_deepcopy=True):
+        """Applies given Transformation repeatedly until no more changes can be made
         and returns a transformed ModelWrapper instance.
         If make_deepcopy is specified, operates on a new (deep)copy of model.
-        Transform must return (transformed_model, model_was_changed)."""
+        """
         transformed_model = self
         if make_deepcopy:
             transformed_model = copy.deepcopy(self)
         model_was_changed = True
         while model_was_changed:
-            (transformed_model, model_was_changed) = transform(transformed_model)
-        return transformed_model
-
-    def transform_single(self, transform, make_deepcopy=True):
-        """Applies given transform once and returns transformed ModelWrapper
-        instance. If make_deepcopy is specified, operates on a new (deep)copy of
-        model. Transform must return (transformed_model, model_was_changed),
-        although model_was_changed is ignored (see also apply_repeated)."""
-        transformed_model = self
-        if make_deepcopy:
-            transformed_model = copy.deepcopy(self)
-        (transformed_model, model_was_changed) = transform(transformed_model)
+            (transformed_model, model_was_changed) = transformation.apply(
+                transformed_model
+            )
         return transformed_model
 
     def check_compatibility(self):
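
Note: a usage sketch of the unified API introduced here (the model path is hypothetical, not part of the patch). transform() deep-copies the model by default and re-applies the Transformation until it reports no further changes:

```python
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes

model = ModelWrapper("model.onnx")  # hypothetical path
# transform() deep-copies by default; pass make_deepcopy=False to modify in place
model = model.transform(InferShapes())
```
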
diff --git a/src/finn/transformation/__init__.py b/src/finn/transformation/__init__.py
index 727b8d32a60793c9f32df2df5120b069fb6528dc..3ddce04c11db01b2f6722bc843f0107621630936 100644
--- a/src/finn/transformation/__init__.py
+++ b/src/finn/transformation/__init__.py
@@ -2,9 +2,10 @@
 Guide to writing FINN transformations
 -------------------------------------
 
-* Your transformation should take in a ModelWrapper, and return a tuple with
- (transformed_model: ModelWrapper, model_was_changed: Bool)
-* The transformations are meant to be applied using the .transform functions
+* Your transformation must inherit from the Transformation abstract base class.
+* Your transformation's apply function should take in a ModelWrapper, and return
+  a tuple with (transformed_model: ModelWrapper, model_was_changed: Bool)
+* The transformations are meant to be applied using the .transform function
   in ModelWrapper. This makes a deep copy of the input model by default, so
   you don't have to.
 * model_was_changed indicates whether your transformation made any changes to
@@ -14,6 +15,17 @@ Guide to writing FINN transformations
 * You MUST return model_was_changed=False at some point when your transformation
   is called multiple times, otherwise transform() will loop infinitely.
 * If you cannot guarantee that the transformation will reach a fixed point,
-  you must declare this and notify the user to use .transform_single() instead
-  of .transform_repeated()
+  you must declare this, return model_was_changed=False, and let the user
+  re-apply the transformation manually.
 """
+
+from abc import ABC, abstractmethod
+
+
+class Transformation(ABC):
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def apply(self, model):
+        pass
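
To make the contract concrete, here is a minimal sketch of a transformation written against the new ABC; the class and its behavior are illustrative, not part of this patch:

```python
from finn.transformation import Transformation


class UppercaseNodeNames(Transformation):
    """Illustrative only: uppercase every node name."""

    def apply(self, model):
        for n in model.graph.node:
            n.name = n.name.upper()
        # a single pass is always enough, so returning False here keeps
        # ModelWrapper.transform() from looping forever
        return (model, False)
```

Applied as model.transform(UppercaseNodeNames()), this runs exactly once because model_was_changed is always False.
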
diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
index 446e27ba50637fc0ef961f23b26f85bc8235b4fc..655ddd9842f37d59155aa0b12edeffecd89d65c1 100644
--- a/src/finn/transformation/batchnorm_to_affine.py
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -2,70 +2,73 @@ import numpy as np
 from onnx import TensorProto
 from onnx import helper as oh
 
-import finn.transformation.infer_shapes as si
+from finn.transformation import Transformation
+from finn.transformation.infer_shapes import InferShapes
 
 
-def batchnorm_to_affine(model):
+class BatchNormToAffine(Transformation):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "BatchNormalization":
-            graph_modified = True
-            bn_input = n.input[0]
-            bn_output = n.output[0]
-            # extract batchnorm parameters as numpy arrays
-            scale = model.get_initializer(n.input[1])
-            bias = model.get_initializer(n.input[2])
-            mean = model.get_initializer(n.input[3])
-            variance = model.get_initializer(n.input[4])
-            epsilon = 1e-5
-            # find A and B to compute batchnorm as affine transpose Ax+B
-            # TODO is a division by moving avg factor needed for variance?
-            A = scale / np.sqrt(epsilon + variance)
-            B = bias - (A * mean)
-            # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
-            producer = model.find_producer(bn_input)
-            if producer is not None:
-                if producer.op_type == "Unsqueeze":
-                    bn_input = producer.input[0]
-            consumer = model.find_consumer(bn_output)
-            if consumer is not None:
-                if consumer.op_type == "Squeeze":
-                    bn_output = consumer.output[0]
-            data_shape = model.get_tensor_shape(bn_input)
-            # create value_info and initializers for Mul and Add constants
-            mul_const = oh.make_tensor_value_info(
-                model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
-            )
-            graph.value_info.append(mul_const)
-            model.set_initializer(mul_const.name, A)
-            mul_output = oh.make_tensor_value_info(
-                model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape
-            )
-            graph.value_info.append(mul_output)
-            add_const = oh.make_tensor_value_info(
-                model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape
-            )
-            graph.value_info.append(add_const)
-            model.set_initializer(add_const.name, B)
-            # create Mul and Add nodes to replace the batchnorm
-            mul_node = oh.make_node(
-                "Mul", [bn_input, mul_const.name], [mul_output.name]
-            )
-            add_node = oh.make_node(
-                "Add", [mul_output.name, add_const.name], [bn_output]
-            )
-            # insert where the batchnorm is to preserve topological ordering
-            graph.node.insert(node_ind, mul_node)
-            graph.node.insert(node_ind + 1, add_node)
-            # remove old nodes
-            graph.node.remove(n)
-            if consumer is not None:
-                graph.node.remove(consumer)
-            if producer is not None:
-                graph.node.remove(producer)
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "BatchNormalization":
+                graph_modified = True
+                bn_input = n.input[0]
+                bn_output = n.output[0]
+                # extract batchnorm parameters as numpy arrays
+                scale = model.get_initializer(n.input[1])
+                bias = model.get_initializer(n.input[2])
+                mean = model.get_initializer(n.input[3])
+                variance = model.get_initializer(n.input[4])
+                epsilon = 1e-5
+                # find A and B to compute batchnorm as affine transform Ax+B
+                # TODO is a division by moving avg factor needed for variance?
+                A = scale / np.sqrt(epsilon + variance)
+                B = bias - (A * mean)
+                # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
+                producer = model.find_producer(bn_input)
+                if producer is not None:
+                    if producer.op_type == "Unsqueeze":
+                        bn_input = producer.input[0]
+                consumer = model.find_consumer(bn_output)
+                if consumer is not None:
+                    if consumer.op_type == "Squeeze":
+                        bn_output = consumer.output[0]
+                data_shape = model.get_tensor_shape(bn_input)
+                # create value_info and initializers for Mul and Add constants
+                mul_const = oh.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
+                )
+                graph.value_info.append(mul_const)
+                model.set_initializer(mul_const.name, A)
+                mul_output = oh.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape
+                )
+                graph.value_info.append(mul_output)
+                add_const = oh.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape
+                )
+                graph.value_info.append(add_const)
+                model.set_initializer(add_const.name, B)
+                # create Mul and Add nodes to replace the batchnorm
+                mul_node = oh.make_node(
+                    "Mul", [bn_input, mul_const.name], [mul_output.name]
+                )
+                add_node = oh.make_node(
+                    "Add", [mul_output.name, add_const.name], [bn_output]
+                )
+                # insert where the batchnorm is to preserve topological ordering
+                graph.node.insert(node_ind, mul_node)
+                graph.node.insert(node_ind + 1, add_node)
+                # remove old nodes
+                graph.node.remove(n)
+                if consumer is not None:
+                    graph.node.remove(consumer)
+                if producer is not None:
+                    graph.node.remove(producer)
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
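
The math is unchanged by the refactoring; a quick numpy sanity check of the affine equivalence used above (values are arbitrary):

```python
import numpy as np

x = np.random.randn(4).astype(np.float32)
scale, bias = np.float32(1.5), np.float32(0.3)
mean, var, eps = np.float32(0.2), np.float32(0.8), 1e-5

# test-time batchnorm: y = scale * (x - mean) / sqrt(var + eps) + bias
bn_out = scale * (x - mean) / np.sqrt(var + eps) + bias
A = scale / np.sqrt(eps + var)  # same A as in the pass
B = bias - A * mean             # same B as in the pass
assert np.allclose(A * x + B, bn_out)
```
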
diff --git a/src/finn/transformation/fold_constants.py b/src/finn/transformation/fold_constants.py
index a9331d3069ab81e919bd6a29f41243dc1e07ab51..5b27d906cc8ee4cbcaf7363001eb7297b5c21000 100644
--- a/src/finn/transformation/fold_constants.py
+++ b/src/finn/transformation/fold_constants.py
@@ -1,31 +1,34 @@
 import finn.core.onnx_exec as oxe
-import finn.transformation.infer_shapes as si
+from finn.transformation import Transformation
+from finn.transformation.infer_shapes import InferShapes
 
 
-def fold_constants(model):
+class FoldConstants(Transformation):
     """Replace the output of a node with const-only inputs with a precomputed
     result."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    execution_context = model.make_empty_exec_context()
-    for n in graph.node:
-        node_ind += 1
-        node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
-        node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
-        node_out = n.output[0]
-        is_all_constant_inputs = len(node_inp_dyn) == 0
-        ishape = model.get_tensor_shape(n.input[0])
-        is_const_shape = (n.op_type == "Shape") and (ishape is not None)
-        if is_all_constant_inputs or is_const_shape:
-            # this node has no dynamic inputs, only constant ones -- so we can
-            # do constant folding.
-            oxe.execute_node(n, execution_context, graph)
-            # use the execution result as an initializer
-            model.set_initializer(node_out, execution_context[node_out])
-            # remove old node
-            graph.node.remove(n)
-            graph_modified = True
-    if graph_modified:
-        model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        execution_context = model.make_empty_exec_context()
+        for n in graph.node:
+            node_ind += 1
+            node_inp_inits = list(map(lambda x: model.get_initializer(x), n.input))
+            node_inp_dyn = list(filter(lambda x: x is None, node_inp_inits))
+            node_out = n.output[0]
+            is_all_constant_inputs = len(node_inp_dyn) == 0
+            ishape = model.get_tensor_shape(n.input[0])
+            is_const_shape = (n.op_type == "Shape") and (ishape is not None)
+            if is_all_constant_inputs or is_const_shape:
+                # this node has no dynamic inputs, only constant ones -- so we can
+                # do constant folding.
+                oxe.execute_node(n, execution_context, graph)
+                # use the execution result as an initializer
+                model.set_initializer(node_out, execution_context[node_out])
+                # remove old node
+                graph.node.remove(n)
+                graph_modified = True
+        if graph_modified:
+            model = model.transform(InferShapes())
+        return (model, graph_modified)
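
Usage follows the pattern in the updated tests: run InferShapes first so the Shape special case has shapes to read; FoldConstants then re-runs InferShapes itself whenever it folds something. The model path below is hypothetical:

```python
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.infer_shapes import InferShapes

model = ModelWrapper("model.onnx")  # hypothetical path
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
```
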
diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 44a4614c660c240f138ad1574cd05e292d1fcd0b..b6845312b06c409fe244ad1126d6ce37c0d85ca2 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -1,59 +1,68 @@
 import finn.core.utils as util
+from finn.transformation import Transformation
 
 
-def give_unique_node_names(model):
+class GiveUniqueNodeNames(Transformation):
     """Give unique names to each node in the graph using enumeration."""
-    optype_count = {}
-    for n in model.graph.node:
-        if n.op_type not in optype_count.keys():
-            optype_count[n.op_type] = 0
-        n.name = "%s_%d" % (n.op_type, optype_count[n.op_type])
-        optype_count[n.op_type] += 1
-    # return model_was_changed = False as single iteration is always enough
-    return (model, False)
 
+    def apply(self, model):
+        optype_count = {}
+        for n in model.graph.node:
+            if n.op_type not in optype_count.keys():
+                optype_count[n.op_type] = 0
+            n.name = "%s_%d" % (n.op_type, optype_count[n.op_type])
+            optype_count[n.op_type] += 1
+        # return model_was_changed = False as single iteration is always enough
+        return (model, False)
 
-def give_random_tensor_names(model):
+
+class GiveRandomTensorNames(Transformation):
     """Give random tensor names to all tensors."""
-    names = model.get_all_tensor_names()
-    for name in names:
-        model.rename_tensor(name, util.random_string())
-    # return model_was_changed = False as single iteration is always enough
-    return (model, False)
+
+    def apply(self, model):
+        names = model.get_all_tensor_names()
+        for name in names:
+            model.rename_tensor(name, util.random_string())
+        # return model_was_changed = False as single iteration is always enough
+        return (model, False)
 
 
-def give_readable_tensor_names(model):
+class GiveReadableTensorNames(Transformation):
     """Give more human-readable names to all internal tensors. It's recommended
     to apply GiveUniqueNodeNames prior to this transform."""
-    # to ensure we can use rename_tensor safely (without renaming existing
-    # tensors) we start by giving random names to all tensors
-    model = model.transform_single(give_random_tensor_names)
-    graph = model.graph
-    for n in graph.node:
-        out_num = 0
-        for o in n.output:
-            model.rename_tensor(o, "%s_out%d" % (n.name, out_num))
-            out_num += 1
-        init_in_num = 0
-        for i in n.input:
-            if model.get_initializer(i) is not None:
-                model.rename_tensor(i, "%s_param%d" % (n.name, init_in_num))
-                init_in_num += 1
-    # give special names to the main model input and output
-    model.rename_tensor(model.graph.input[0].name, "global_in")
-    model.rename_tensor(model.graph.output[0].name, "global_out")
-    # return model_was_changed = False as single iteration is always enough
-    return (model, False)
-
-
-def convert_sub_to_add(model):
+
+    def apply(self, model):
+        # to ensure we can use rename_tensor safely (without renaming existing
+        # tensors) we start by giving random names to all tensors
+        model = model.transform(GiveRandomTensorNames())
+        graph = model.graph
+        for n in graph.node:
+            out_num = 0
+            for o in n.output:
+                model.rename_tensor(o, "%s_out%d" % (n.name, out_num))
+                out_num += 1
+            init_in_num = 0
+            for i in n.input:
+                if model.get_initializer(i) is not None:
+                    model.rename_tensor(i, "%s_param%d" % (n.name, init_in_num))
+                    init_in_num += 1
+        # give special names to the main model input and output
+        model.rename_tensor(model.graph.input[0].name, "global_in")
+        model.rename_tensor(model.graph.output[0].name, "global_out")
+        # return model_was_changed = False as single iteration is always enough
+        return (model, False)
+
+
+class ConvertSubToAdd(Transformation):
     """Convert sub nodes to add nodes of appropriate sign."""
-    graph = model.graph
-    for n in graph.node:
-        if n.op_type == "Sub":
-            A = model.get_initializer(n.input[1])
-            if A is not None:
-                n.op_type = "Add"
-                model.set_initializer(n.input[1], -A)
-    # return model_was_changed = False as single iteration is always enough
-    return (model, False)
+
+    def apply(self, model):
+        graph = model.graph
+        for n in graph.node:
+            if n.op_type == "Sub":
+                A = model.get_initializer(n.input[1])
+                if A is not None:
+                    n.op_type = "Add"
+                    model.set_initializer(n.input[1], -A)
+        # return model_was_changed = False as single iteration is always enough
+        return (model, False)
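
A sketch of how the naming passes compose, with example names derived from the format strings above (the model path is hypothetical):

```python
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames

model = ModelWrapper("model.onnx")  # hypothetical path
model = model.transform(GiveUniqueNodeNames())      # nodes become MatMul_0, MatMul_1, ...
model = model.transform(GiveReadableTensorNames())  # tensors become MatMul_0_out0,
                                                    # MatMul_0_param0, global_in, global_out
```
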
diff --git a/src/finn/transformation/infer_datatypes.py b/src/finn/transformation/infer_datatypes.py
index 19d947b57d045d4d7f2523f0f392adaba5bb367a..a311012fc6631e76e75a37b8dc4d1b99d21ce7c7 100644
--- a/src/finn/transformation/infer_datatypes.py
+++ b/src/finn/transformation/infer_datatypes.py
@@ -1,4 +1,5 @@
 from finn.core.datatype import DataType
+from finn.transformation import Transformation
 
 
 def _infer_node_datatype(model, node):
@@ -37,11 +38,13 @@ def _infer_node_datatype(model, node):
     return graph_modified
 
 
-def infer_datatypes(model):
+class InferDataTypes(Transformation):
     """Infer FINN DataType info for all intermediate/output tensors based on
-  inputs and node type."""
-    graph = model.graph
-    graph_modified = False
-    for node in graph.node:
-        graph_modified |= _infer_node_datatype(model, node)
-    return (model, graph_modified)
+    inputs and node type."""
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        for node in graph.node:
+            graph_modified |= _infer_node_datatype(model, node)
+        return (model, graph_modified)
diff --git a/src/finn/transformation/infer_shapes.py b/src/finn/transformation/infer_shapes.py
index 063fb9b8b686a880283678e1bf779c09da209a5a..e92c6f81625a9d328bed3225a7730a943d6c9830 100644
--- a/src/finn/transformation/infer_shapes.py
+++ b/src/finn/transformation/infer_shapes.py
@@ -2,6 +2,7 @@ import onnx.helper as helper
 import onnx.shape_inference as si
 
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation import Transformation
 
 
 def _make_shape_compatible_op(node):
@@ -9,7 +11,7 @@ def _make_shape_compatible_op(node):
     shape inference with custom ops."""
     assert node.domain == "finn"
     if node.op_type == "MultiThreshold":
-        return helper.make_node("ReLU", [node.input[0]], [node.output[0]])
+        return helper.make_node("Relu", [node.input[0]], [node.output[0]])
     else:
         raise Exception("No known shape-compatible op for %s" % node.op_type)
 
@@ -44,12 +46,14 @@ def _restore_finn_ops(model, hidden_ops):
             pass
 
 
-def infer_shapes(model):
+class InferShapes(Transformation):
     """Ensure every tensor in the model has a specified shape (ValueInfo)."""
-    # hide your riches!
-    hidden_ops = _hide_finn_ops(model)
-    # call regular ONNX shape inference
-    model = ModelWrapper(si.infer_shapes(model.model))
-    # bring back hidden ops
-    _restore_finn_ops(model, hidden_ops)
-    return (model, False)
+
+    def apply(self, model):
+        # hide your riches!
+        hidden_ops = _hide_finn_ops(model)
+        # call regular ONNX shape inference
+        model = ModelWrapper(si.infer_shapes(model.model))
+        # bring back hidden ops
+        _restore_finn_ops(model, hidden_ops)
+        return (model, False)
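
The Relu stand-in works because ONNX shape inference only needs an op whose output shape matches its input; any elementwise op would do. A minimal sketch of the substitution idea on a plain ONNX graph (names and shapes are illustrative):

```python
import onnx.helper as helper
import onnx.shape_inference as si
from onnx import TensorProto

# a domain="finn" op would make stock ONNX shape inference give up, so the
# pass temporarily swaps in a shape-compatible standard op such as Relu
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
outp = helper.make_tensor_value_info("y", TensorProto.FLOAT, None)
standin = helper.make_node("Relu", ["x"], ["y"])  # stand-in for MultiThreshold
graph = helper.make_graph([standin], "sketch", [inp], [outp])
inferred = si.infer_shapes(helper.make_model(graph))
# the inferred model now carries shape information for y ([1, 4])
```
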
diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py
index 5053b23c3e09bc5f32807cdf16a997c4647dd165..fb9e530d641063187dd75de30c8ca49936565bae 100644
--- a/src/finn/transformation/streamline.py
+++ b/src/finn/transformation/streamline.py
@@ -1,384 +1,416 @@
 import numpy as np
 from onnx import helper as oh
 
-import finn.transformation.infer_shapes as si
 from finn.core.datatype import DataType
+from finn.transformation import Transformation
+from finn.transformation.infer_shapes import InferShapes
 
 
-def convert_sign_to_thres(model):
+class ConvertSignToThres(Transformation):
     """Convert Sign node instances to MultiThreshold with threshold at 0."""
-    graph = model.graph
-    graph_modified = False
-    node_ind = 0
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Sign":
-            sign_out_name = n.output[0]
-            # find consumer
-            consumer = model.find_consumer(sign_out_name)
-            assert consumer is not None
-            # change op type and create threshold
-            n.op_type = "MultiThreshold"
-            thres_param_name = model.make_new_valueinfo_name()
-            thres_param = np.asarray([[0]], dtype=np.float32)
-            n.input.append(thres_param_name)
-            n.domain = "finn"
-            model.set_initializer(thres_param_name, thres_param)
-            # convert 0,1 -> -1,+1 with 2*x-1
-            out_shape = model.get_tensor_shape(sign_out_name)
-            # make a mul node
-            # note how set_initializer or set_tensor_shape is called before
-            # calling make_new_valueinfo_name again
-            mul_param_name = model.make_new_valueinfo_name()
-            model.set_initializer(mul_param_name, np.asarray([[2]], dtype=np.float32))
-            mul_out_name = model.make_new_valueinfo_name()
-            model.set_tensor_shape(mul_out_name, out_shape)
-            mul_node = oh.make_node(
-                "Mul", [sign_out_name, mul_param_name], [mul_out_name]
-            )
-            # make an add node
-            add_param_name = model.make_new_valueinfo_name()
-            model.set_initializer(add_param_name, np.asarray([[-1]], dtype=np.float32))
-            add_out_name = model.make_new_valueinfo_name()
-            model.set_tensor_shape(add_out_name, out_shape)
-            add_node = oh.make_node(
-                "Add", [mul_out_name, add_param_name], [add_out_name]
-            )
-            # add new nodes to graph at correct position
-            graph.node.insert(node_ind, mul_node)
-            graph.node.insert(node_ind + 1, add_node)
-            # rewrite consumer's input
-            consumer.input[0] = add_out_name
-            # add quantization annotations
-            model.set_tensor_datatype(sign_out_name, DataType.BINARY)
-            model.set_tensor_datatype(mul_out_name, DataType.UINT2)
-            model.set_tensor_datatype(add_out_name, DataType.BIPOLAR)
-            graph_modified = True
-    return (model, graph_modified)
-
-
-def collapse_repeated_op(model, op_name, make_collapsed_param_fxn):
-    """Collapse repeated consecutive operations with constant parameters into
-    a single operation. make_collapsed_param_fxn must take two tensors and
-    return a tensor which gives the equivalent result using a single op. """
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == op_name:
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == op_name:
-                op0_param_name = n.input[1]
-                op1_param_name = consumer.input[1]
-                op0_param = model.get_initializer(op0_param_name)
-                op1_param = model.get_initializer(op1_param_name)
-                assert op0_param is not None
-                assert op1_param is not None
-                start_name = n.input[0]
-                end_name = consumer.output[0]
-                # compute the new parameter
-                new_param = make_collapsed_param_fxn(op0_param, op1_param)
-                # make and insert new node
-                new_node_param_name = op0_param_name
-                new_node = oh.make_node(
-                    op_name, [start_name, new_node_param_name], [end_name]
-                )
-                graph.node.insert(node_ind, new_node)
-                # replace parameter value
-                model.set_initializer(new_node_param_name, new_param)
-                # remove old nodes
-                graph.node.remove(n)
-                graph.node.remove(consumer)
-                graph_modified = True
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
-
 
-def collapse_repeated_add(model):
-    return collapse_repeated_op(model, "Add", lambda x, y: y + x)
-
-
-def collapse_repeated_mul(model):
-    return collapse_repeated_op(model, "Mul", lambda x, y: y * x)
-
-
-def move_add_past_mul(model):
-    """Move add operations past multiply operations. The aim is to have them
-    next to each other such that they can be collapsed into a single add."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Add":
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "Mul":
-                # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
-                # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
-                # assume input 0 is from the previous layer, input 1 is the
-                # trained (constant) parameter
-                mul_weight_name = consumer.input[1]
-                add_weight_name = n.input[1]
-                A = model.get_initializer(mul_weight_name)
-                B = model.get_initializer(add_weight_name)
-                assert A is not None
-                assert B is not None
-                start_name = n.input[0]
-                middle_name = n.output[0]
-                end_name = consumer.output[0]
-                # compute new param value for add
-                BA = B * A
-                # make and insert new nodes
-                new_mul = oh.make_node(
-                    "Mul", [start_name, mul_weight_name], [middle_name]
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Sign":
+                sign_out_name = n.output[0]
+                # find consumer
+                consumer = model.find_consumer(sign_out_name)
+                assert consumer is not None
+                # change op type and create threshold
+                n.op_type = "MultiThreshold"
+                thres_param_name = model.make_new_valueinfo_name()
+                thres_param = np.asarray([[0]], dtype=np.float32)
+                n.input.append(thres_param_name)
+                n.domain = "finn"
+                model.set_initializer(thres_param_name, thres_param)
+                # convert 0,1 -> -1,+1 with 2*x-1
+                out_shape = model.get_tensor_shape(sign_out_name)
+                # make a mul node
+                # note how set_initializer or set_tensor_shape is called before
+                # calling make_new_valueinfo_name again
+                mul_param_name = model.make_new_valueinfo_name()
+                model.set_initializer(
+                    mul_param_name, np.asarray([[2]], dtype=np.float32)
+                )
+                mul_out_name = model.make_new_valueinfo_name()
+                model.set_tensor_shape(mul_out_name, out_shape)
+                mul_node = oh.make_node(
+                    "Mul", [sign_out_name, mul_param_name], [mul_out_name]
                 )
-                new_add = oh.make_node(
-                    "Add", [middle_name, add_weight_name], [end_name]
+                # make an add node
+                add_param_name = model.make_new_valueinfo_name()
+                model.set_initializer(
+                    add_param_name, np.asarray([[-1]], dtype=np.float32)
                 )
-                graph.node.insert(node_ind, new_mul)
-                graph.node.insert(node_ind + 1, new_add)
-                # replace add value
-                model.set_initializer(add_weight_name, BA)
-                # remove old nodes
-                graph.node.remove(n)
-                graph.node.remove(consumer)
+                add_out_name = model.make_new_valueinfo_name()
+                model.set_tensor_shape(add_out_name, out_shape)
+                add_node = oh.make_node(
+                    "Add", [mul_out_name, add_param_name], [add_out_name]
+                )
+                # add new nodes to graph at correct position
+                graph.node.insert(node_ind, mul_node)
+                graph.node.insert(node_ind + 1, add_node)
+                # rewrite consumer's input
+                consumer.input[0] = add_out_name
+                # add quantization annotations
+                model.set_tensor_datatype(sign_out_name, DataType.BINARY)
+                model.set_tensor_datatype(mul_out_name, DataType.UINT2)
+                model.set_tensor_datatype(add_out_name, DataType.BIPOLAR)
                 graph_modified = True
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+        return (model, graph_modified)
 
 
-def move_scalar_mul_past_matmul(model):
-    """Move scalar mul operations past matmul operations. We want to have muls
-    next to each other such that they can be collapsed into a single mul."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Mul":
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "MatMul":
-                mul_weight_name = n.input[1]
-                matmul_weight_name = consumer.input[1]
-                A = model.get_initializer(mul_weight_name)
-                W = model.get_initializer(matmul_weight_name)
-                assert A is not None
-                assert W is not None
-                start_name = n.input[0]
-                middle_name = n.output[0]
-                end_name = consumer.output[0]
-                mm_out_shape = model.get_tensor_shape(end_name)
-                if all(x == 1 for x in A.shape):
-                    # if the mul is scalar, we can simply swap the order of ops
-                    # make and insert new nodes
-                    new_matmul = oh.make_node(
-                        "MatMul", [start_name, matmul_weight_name], [middle_name]
-                    )
-                    new_mul = oh.make_node(
-                        "Mul", [middle_name, mul_weight_name], [end_name]
+class CollapseRepeatedOp(Transformation):
+    """Collapse repeated consecutive operations with constant parameters into
+    a single operation. make_collapsed_param_fxn must take two tensors and
+    return a tensor which gives the equivalent result using a single op."""
+
+    def __init__(self, op_name, make_collapsed_param_fxn):
+        super().__init__()
+        self.op_name = op_name
+        self.make_collapsed_param_fxn = make_collapsed_param_fxn
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == self.op_name:
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == self.op_name:
+                    op0_param_name = n.input[1]
+                    op1_param_name = consumer.input[1]
+                    op0_param = model.get_initializer(op0_param_name)
+                    op1_param = model.get_initializer(op1_param_name)
+                    assert op0_param is not None
+                    assert op1_param is not None
+                    start_name = n.input[0]
+                    end_name = consumer.output[0]
+                    # compute the new parameter
+                    new_param = self.make_collapsed_param_fxn(op0_param, op1_param)
+                    # make and insert new node
+                    new_node_param_name = op0_param_name
+                    new_node = oh.make_node(
+                        self.op_name, [start_name, new_node_param_name], [end_name]
                     )
-                    graph.node.insert(node_ind, new_matmul)
-                    graph.node.insert(node_ind + 1, new_mul)
-                    model.set_tensor_shape(middle_name, mm_out_shape)
+                    graph.node.insert(node_ind, new_node)
+                    # replace parameter value
+                    model.set_initializer(new_node_param_name, new_param)
                     # remove old nodes
                     graph.node.remove(n)
                     graph.node.remove(consumer)
                     graph_modified = True
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
 
 
-def move_scalar_add_past_matmul(model):
-    """Move scalar add operations past matmul operations. We want to have adds
+class CollapseRepeatedAdd(CollapseRepeatedOp):
+    def __init__(self):
+        super().__init__("Add", lambda x, y: y + x)
+
+
+class CollapseRepeatedMul(CollapseRepeatedOp):
+    def __init__(self):
+        super().__init__("Mul", lambda x, y: y * x)
+
+
+class MoveAddPastMul(Transformation):
+    """Move add operations past multiply operations. The aim is to have them
     next to each other such that they can be collapsed into a single add."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Add":
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "MatMul":
-                add_weight_name = n.input[1]
-                matmul_weight_name = consumer.input[1]
-                A = model.get_initializer(add_weight_name)
-                W = model.get_initializer(matmul_weight_name)
-                assert A is not None
-                assert W is not None
-                start_name = n.input[0]
-                middle_name = n.output[0]
-                end_name = consumer.output[0]
-                mm_out_shape = model.get_tensor_shape(end_name)
-                if all(x == 1 for x in A.shape):
-                    # if the add is scalar, we can move it past the matmul
-                    # by taking it past the matmul with a dot product
-                    Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W)
-                    # update the add weight
-                    model.set_initializer(add_weight_name, Anew)
-                    new_matmul = oh.make_node(
-                        "MatMul", [start_name, matmul_weight_name], [middle_name]
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Mul":
+                    # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA)
+                    # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA)
+                    # assume input 0 is from the previous layer, input 1 is the
+                    # trained (constant) parameter
+                    mul_weight_name = consumer.input[1]
+                    add_weight_name = n.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    B = model.get_initializer(add_weight_name)
+                    assert A is not None
+                    assert B is not None
+                    start_name = n.input[0]
+                    middle_name = n.output[0]
+                    end_name = consumer.output[0]
+                    # compute new param value for add
+                    BA = B * A
+                    # make and insert new nodes
+                    new_mul = oh.make_node(
+                        "Mul", [start_name, mul_weight_name], [middle_name]
                     )
                     new_add = oh.make_node(
                         "Add", [middle_name, add_weight_name], [end_name]
                     )
-                    graph.node.insert(node_ind, new_matmul)
+                    graph.node.insert(node_ind, new_mul)
                     graph.node.insert(node_ind + 1, new_add)
-                    model.set_tensor_shape(middle_name, mm_out_shape)
+                    # replace add value
+                    model.set_initializer(add_weight_name, BA)
                     # remove old nodes
                     graph.node.remove(n)
                     graph.node.remove(consumer)
                     graph_modified = True
-    model = model.transform_single(si.infer_shapes)
-    return (model, graph_modified)
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
 
 
-def absorb_add_into_multi_threshold(model):
+class MoveScalarMulPastMatMul(Transformation):
+    """Move scalar mul operations past matmul operations. We want to have muls
+    next to each other such that they can be collapsed into a single mul."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Mul":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MatMul":
+                    mul_weight_name = n.input[1]
+                    matmul_weight_name = consumer.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    W = model.get_initializer(matmul_weight_name)
+                    assert A is not None
+                    assert W is not None
+                    start_name = n.input[0]
+                    middle_name = n.output[0]
+                    end_name = consumer.output[0]
+                    mm_out_shape = model.get_tensor_shape(end_name)
+                    if all(x == 1 for x in A.shape):
+                        # if the mul is scalar, we can simply swap the order of ops
+                        # make and insert new nodes
+                        new_matmul = oh.make_node(
+                            "MatMul", [start_name, matmul_weight_name], [middle_name]
+                        )
+                        new_mul = oh.make_node(
+                            "Mul", [middle_name, mul_weight_name], [end_name]
+                        )
+                        graph.node.insert(node_ind, new_matmul)
+                        graph.node.insert(node_ind + 1, new_mul)
+                        model.set_tensor_shape(middle_name, mm_out_shape)
+                        # remove old nodes
+                        graph.node.remove(n)
+                        graph.node.remove(consumer)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
+class MoveScalarAddPastMatMul(Transformation):
+    """Move scalar add operations past matmul operations. We want to have adds
+    next to each other such that they can be collapsed into a single add."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MatMul":
+                    add_weight_name = n.input[1]
+                    matmul_weight_name = consumer.input[1]
+                    A = model.get_initializer(add_weight_name)
+                    W = model.get_initializer(matmul_weight_name)
+                    assert A is not None
+                    assert W is not None
+                    start_name = n.input[0]
+                    middle_name = n.output[0]
+                    end_name = consumer.output[0]
+                    mm_out_shape = model.get_tensor_shape(end_name)
+                    if all(x == 1 for x in A.shape):
+                        # a scalar add moves past the matmul since
+                        # (x + a)W = xW + (a * ones)W, i.e. a dot product with W
+                        Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W)
+                        # update the add weight
+                        model.set_initializer(add_weight_name, Anew)
+                        new_matmul = oh.make_node(
+                            "MatMul", [start_name, matmul_weight_name], [middle_name]
+                        )
+                        new_add = oh.make_node(
+                            "Add", [middle_name, add_weight_name], [end_name]
+                        )
+                        graph.node.insert(node_ind, new_matmul)
+                        graph.node.insert(node_ind + 1, new_add)
+                        model.set_tensor_shape(middle_name, mm_out_shape)
+                        # remove old nodes
+                        graph.node.remove(n)
+                        graph.node.remove(consumer)
+                        graph_modified = True
+        model = model.transform(InferShapes())
+        return (model, graph_modified)
+
+
+class AbsorbAddIntoMultiThreshold(Transformation):
     """Absorb preceding Add ops into MultiThreshold by updating the threshold
     values."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Add":
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "MultiThreshold":
-                add_weight_name = n.input[1]
-                threshold_name = consumer.input[1]
-                A = model.get_initializer(add_weight_name)
-                T = model.get_initializer(threshold_name)
-                assert A is not None
-                assert T is not None
-                start_name = n.input[0]
-                # compute new thresholds and set initializer
-                Tnew = T - A.reshape(-1, T.shape[1])
-                model.set_initializer(threshold_name, Tnew)
-                # wire add input directly to MultiThreshold
-                consumer.input[0] = start_name
-                # remove the add node
-                graph.node.remove(n)
-                graph_modified = True
-    return (model, graph_modified)
 
-
-def absorb_mul_into_multi_threshold(model):
-    """Absorb preceding Mul ops into MultiThreshold by updating the threshold
-    values. Only *positive* scalar/1D vectors can be absorbed."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Mul":
-            mul_weight_name = n.input[1]
-            A = model.get_initializer(mul_weight_name)
-            assert A is not None
-            is_signed = (A < 0).any()
-            is_scalar = np.prod(A.shape) == 1
-            is_1d = len(A.shape) == 2 and A.shape[0] == 1
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "MultiThreshold":
-                if not is_signed and (is_1d or is_scalar):
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Add":
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MultiThreshold":
+                    add_weight_name = n.input[1]
                     threshold_name = consumer.input[1]
+                    A = model.get_initializer(add_weight_name)
                     T = model.get_initializer(threshold_name)
+                    assert A is not None
                     assert T is not None
                     start_name = n.input[0]
                     # compute new thresholds and set initializer
-                    Tnew = T / A.reshape(-1, T.shape[1])
-                    # TODO: need to handle negative A values correctly; produce
-                    # mul sign mask and merge into preceding matmul?
+                    Tnew = T - A.reshape(-1, T.shape[1])
                     model.set_initializer(threshold_name, Tnew)
                     # wire add input directly to MultiThreshold
                     consumer.input[0] = start_name
-                    # remove the mul node
+                    # remove the add node
                     graph.node.remove(n)
                     graph_modified = True
-    return (model, graph_modified)
+        return (model, graph_modified)
+
+
+class AbsorbMulIntoMultiThreshold(Transformation):
+    """Absorb preceding Mul ops into MultiThreshold by updating the threshold
+    values. Only *positive* scalar/1D vectors can be absorbed."""
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Mul":
+                mul_weight_name = n.input[1]
+                A = model.get_initializer(mul_weight_name)
+                assert A is not None
+                is_signed = (A < 0).any()
+                is_scalar = np.prod(A.shape) == 1
+                is_1d = len(A.shape) == 2 and A.shape[0] == 1
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "MultiThreshold":
+                    if not is_signed and (is_1d or is_scalar):
+                        threshold_name = consumer.input[1]
+                        T = model.get_initializer(threshold_name)
+                        assert T is not None
+                        start_name = n.input[0]
+                        # compute new thresholds and set initializer
+                        Tnew = T / A.reshape(-1, T.shape[1])
+                        # TODO: need to handle negative A values correctly; produce
+                        # mul sign mask and merge into preceding matmul?
+                        model.set_initializer(threshold_name, Tnew)
+                        # wire mul input directly to MultiThreshold
+                        consumer.input[0] = start_name
+                        # remove the mul node
+                        graph.node.remove(n)
+                        graph_modified = True
+        return (model, graph_modified)
 
 
-def factor_out_mul_sign_magnitude(model):
+class FactorOutMulSignMagnitude(Transformation):
     """Split multiply-by-constant nodes into two multiply-by-constant nodes,
     where the first node multiplies by a bipolar vector (of signs) and the
     second by a vector of magnitudes."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "Mul":
-            mul_weight_name = n.input[1]
-            A = model.get_initializer(mul_weight_name)
-            assert A is not None
-            is_scalar = np.prod(A.shape) == 1
-            is_1d = len(A.shape) == 2 and A.shape[0] == 1
-            is_not_bipolar = (
-                model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR
-            )
-            is_signed = (A < 0).any()
-            if is_signed and (is_scalar or is_1d) and is_not_bipolar:
-                start_name = n.input[0]
-                in_shape = model.get_tensor_shape(start_name)
-                middle_name = model.make_new_valueinfo_name()
-                model.set_tensor_shape(middle_name, in_shape)
-                sign_mul_param_name = model.make_new_valueinfo_name()
-                # create new mul node with sign(A) as the operand
-                sgn = np.sign(A)
-                model.set_initializer(sign_mul_param_name, sgn)
-                model.set_tensor_datatype(sign_mul_param_name, DataType.BIPOLAR)
-                # replace original mul weight by magnitudes
-                model.set_initializer(mul_weight_name, np.abs(A))
-                new_mul = oh.make_node(
-                    "Mul", [start_name, sign_mul_param_name], [middle_name]
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "Mul":
+                mul_weight_name = n.input[1]
+                A = model.get_initializer(mul_weight_name)
+                assert A is not None
+                is_scalar = np.prod(A.shape) == 1
+                is_1d = len(A.shape) == 2 and A.shape[0] == 1
+                is_not_bipolar = (
+                    model.get_tensor_datatype(mul_weight_name) != DataType.BIPOLAR
                 )
-                n.input[0] = middle_name
-                graph.node.insert(node_ind - 1, new_mul)
-                graph_modified = True
-    return (model, graph_modified)
+                is_signed = (A < 0).any()
+                if is_signed and (is_scalar or is_1d) and is_not_bipolar:
+                    start_name = n.input[0]
+                    in_shape = model.get_tensor_shape(start_name)
+                    middle_name = model.make_new_valueinfo_name()
+                    model.set_tensor_shape(middle_name, in_shape)
+                    sign_mul_param_name = model.make_new_valueinfo_name()
+                    # create new mul node with sign(A) as the operand
+                    sgn = np.sign(A)
+                    model.set_initializer(sign_mul_param_name, sgn)
+                    model.set_tensor_datatype(sign_mul_param_name, DataType.BIPOLAR)
+                    # replace original mul weight by magnitudes
+                    model.set_initializer(mul_weight_name, np.abs(A))
+                    new_mul = oh.make_node(
+                        "Mul", [start_name, sign_mul_param_name], [middle_name]
+                    )
+                    n.input[0] = middle_name
+                    graph.node.insert(node_ind - 1, new_mul)
+                    graph_modified = True
+        return (model, graph_modified)
 
 
-def absorb_1bit_mul_into_matmul(model):
+class Absorb1BitMulIntoMatMul(Transformation):
     """Absorb bipolar or binary multiplications into the preciding matrix
     multiply."""
-    graph = model.graph
-    node_ind = 0
-    graph_modified = False
-    for n in graph.node:
-        node_ind += 1
-        if n.op_type == "MatMul":
-            matmul_weight_name = n.input[1]
-            W = model.get_initializer(matmul_weight_name)
-            assert W is not None
-            consumer = model.find_consumer(n.output[0])
-            if consumer is not None and consumer.op_type == "Mul":
-                mul_weight_name = consumer.input[1]
-                A = model.get_initializer(mul_weight_name)
-                assert A is not None
-                is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
-                if is_1bit:
-                    Wnew = A * W
-                    assert Wnew.shape == W.shape
-                    model.set_initializer(matmul_weight_name, Wnew)
-                    n.output[0] = consumer.output[0]
-                    graph.node.remove(consumer)
-                    graph_modified = True
-    return (model, graph_modified)
+
+    def apply(self, model):
+        graph = model.graph
+        node_ind = 0
+        graph_modified = False
+        for n in graph.node:
+            node_ind += 1
+            if n.op_type == "MatMul":
+                matmul_weight_name = n.input[1]
+                W = model.get_initializer(matmul_weight_name)
+                assert W is not None
+                consumer = model.find_consumer(n.output[0])
+                if consumer is not None and consumer.op_type == "Mul":
+                    mul_weight_name = consumer.input[1]
+                    A = model.get_initializer(mul_weight_name)
+                    assert A is not None
+                    is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
+                    if is_1bit:
+                        Wnew = A * W
+                        assert Wnew.shape == W.shape
+                        model.set_initializer(matmul_weight_name, Wnew)
+                        n.output[0] = consumer.output[0]
+                        graph.node.remove(consumer)
+                        graph_modified = True
+        return (model, graph_modified)
 
 
-def round_thresholds(model):
+class RoundThresholds(Transformation):
     """For MultiThreshold nodes operating on integer inputs, round up
     threshold values to the nearest integer."""
-    graph = model.graph
-    graph_modified = False
-    for n in graph.node:
-        if n.op_type == "MultiThreshold":
-            idtype = model.get_tensor_datatype(n.input[0])
-            T = model.get_initializer(n.input[1])
-            Tnew = np.ceil(T)
-            if idtype.is_integer() and (T != Tnew).any():
-                # round up the thresholds to nearest integer
-                model.set_initializer(n.input[1], Tnew)
-                # use same datatype as inputs for thresholds
-                model.set_tensor_datatype(n.input[1], idtype)
-                graph_modified = True
-    return (model, graph_modified)
+
+    def apply(self, model):
+        graph = model.graph
+        graph_modified = False
+        for n in graph.node:
+            if n.op_type == "MultiThreshold":
+                idtype = model.get_tensor_datatype(n.input[0])
+                T = model.get_initializer(n.input[1])
+                Tnew = np.ceil(T)
+                if idtype.is_integer() and (T != Tnew).any():
+                    # round up the thresholds to nearest integer
+                    model.set_initializer(n.input[1], Tnew)
+                    # use same datatype as inputs for thresholds
+                    model.set_tensor_datatype(n.input[1], idtype)
+                    graph_modified = True
+        return (model, graph_modified)
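
The reordering passes above rest on simple algebraic identities; a quick numpy check (shapes and values arbitrary):

```python
import numpy as np

x = np.random.randn(1, 3).astype(np.float32)
W = np.random.randn(3, 2).astype(np.float32)
a = np.float32(0.5)   # scalar parameter
b, m = np.float32(0.2), np.float32(1.5)

# MoveAddPastMul: (x + b) * m == x * m + (b * m)
assert np.allclose((x + b) * m, x * m + b * m)

# MoveScalarMulPastMatMul: (x * a) @ W == (x @ W) * a
assert np.allclose((x * a) @ W, (x @ W) * a)

# MoveScalarAddPastMatMul: (x + a) @ W == x @ W + dot(a * ones, W)
a_new = np.dot(a * np.ones(W.shape[0], dtype=np.float32), W)
assert np.allclose((x + a) @ W, x @ W + a_new)
```
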
diff --git a/tests/test_basic_onnx_exec.py b/tests/test_basic_onnx_exec.py
index 30e9106febcbf9a84995ceb1d03f4c3782291d8d..c7b3da1b78385d36fc73790b22336242141d5255 100644
--- a/tests/test_basic_onnx_exec.py
+++ b/tests/test_basic_onnx_exec.py
@@ -5,15 +5,15 @@ import onnx
 import onnx.numpy_helper as np_helper
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_mnist_onnx_download_extract_run():
     # load the onnx model
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb")
diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index bb66f98f490069ffc12c4503a5a897c37bcf93b8..dec01b53a9081c6c04143e4465d603d66fc771a6 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -8,10 +8,10 @@ import torch
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.batchnorm_to_affine as tx
-import finn.transformation.fold_constants as fc
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.batchnorm_to_affine import BatchNormToAffine
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 transformed_onnx_path = "test_output_lfc_transformed.onnx"
@@ -27,9 +27,9 @@ def test_batchnorm_to_affine():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
-    new_model = model.transform_single(tx.batchnorm_to_affine)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    new_model = model.transform(BatchNormToAffine())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
diff --git a/tests/test_brevitas_export.py b/tests/test_brevitas_export.py
index 80850edb72af4f37e0e626df5b26513450d0c9d3..641e5e3c49b917ad06d649d797e6f9bc9f30f170 100644
--- a/tests/test_brevitas_export.py
+++ b/tests/test_brevitas_export.py
@@ -9,21 +9,25 @@ import torch
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.fold_constants as fc
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 # TODO get from config instead, hardcoded to Docker path for now
-trained_lfc_checkpoint = (
+trained_lfc_w1a1_checkpoint = (
     "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar"
 )
 
+trained_lfc_w1a2_checkpoint = (
+    "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W2A/checkpoints/best.tar"
+)
+
 
-def test_brevitas_trained_lfc_pytorch():
+def test_brevitas_trained_lfc_w1a1_pytorch():
     # load pretrained weights into LFC-w1a1
     lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1).eval()
-    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
+    checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu")
     lfc.load_state_dict(checkpoint["state_dict"])
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
@@ -49,14 +53,68 @@ def test_brevitas_trained_lfc_pytorch():
     assert np.isclose(produced, expected, atol=1e-4).all()
 
 
-def test_brevitas_to_onnx_export_and_exec():
+def test_brevitas_trained_lfc_w1a2_pytorch():
+    # load pretrained weights into LFC-w1a2
+    lfc = LFC(weight_bit_width=1, act_bit_width=2, in_bit_width=2).eval()
+    checkpoint = torch.load(trained_lfc_w1a2_checkpoint, map_location="cpu")
+    lfc.load_state_dict(checkpoint["state_dict"])
+    # load one of the test vectors
+    raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
+    input_tensor = onnx.load_tensor_from_string(raw_i)
+    input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()
+    assert input_tensor.shape == (1, 1, 28, 28)
+    # do forward pass in PyTorch/Brevitas
+    produced = lfc.forward(input_tensor).detach().numpy()
+    expected = [
+        [
+            4.598069,
+            -6.3698025,
+            10.75695,
+            0.3796571,
+            1.4764442,
+            -5.4417515,
+            -1.8982856,
+            -5.610488,
+            6.116698,
+            0.21092065,
+        ]
+    ]
+    assert np.isclose(produced, expected, atol=1e-4).all()
+
+
+def test_brevitas_to_onnx_export_and_exec_lfc_w1a1():
     lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
-    checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
+    checkpoint = torch.load(trained_lfc_w1a1_checkpoint, map_location="cpu")
+    lfc.load_state_dict(checkpoint["state_dict"])
+    bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
+    model = ModelWrapper(export_onnx_path)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    # load one of the test vectors
+    raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
+    input_tensor = onnx.load_tensor_from_string(raw_i)
+    # run using FINN-based execution
+    input_dict = {"0": nph.to_array(input_tensor)}
+    output_dict = oxe.execute_onnx(model, input_dict)
+    produced = output_dict[list(output_dict.keys())[0]]
+    # run using PyTorch/Brevitas
+    input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()
+    assert input_tensor.shape == (1, 1, 28, 28)
+    # do forward pass in PyTorch/Brevitas
+    expected = lfc.forward(input_tensor).detach().numpy()
+    assert np.isclose(produced, expected, atol=1e-3).all()
+    # remove the exported model file
+    os.remove(export_onnx_path)
+
+
+def test_brevitas_to_onnx_export_and_exec_lfc_w1a2():
+    lfc = LFC(weight_bit_width=1, act_bit_width=2, in_bit_width=2)
+    checkpoint = torch.load(trained_lfc_w1a2_checkpoint, map_location="cpu")
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
diff --git a/tests/test_collapse_repeated_op.py b/tests/test_collapse_repeated_op.py
index d8cdde3c654873b368653347863af05317bc5bcb..d97cdbc3033000b94e499988155b3af829040109 100644
--- a/tests/test_collapse_repeated_op.py
+++ b/tests/test_collapse_repeated_op.py
@@ -3,9 +3,9 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.core.onnx_exec as ox
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as tx
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import CollapseRepeatedAdd, CollapseRepeatedMul
 
 
 def test_collapse_repeated_op():
@@ -30,12 +30,12 @@ def test_collapse_repeated_op():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("add_param_0", np.asarray([1, 3], dtype=np.float32))
     model.set_initializer("add_param_1", np.asarray([-1, 3], dtype=np.float32))
     model.set_initializer("mul_param_0", np.asarray([2, 4], dtype=np.float32))
     model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32))
-    new_model = model.transform_repeated(tx.collapse_repeated_add)
-    new_model = new_model.transform_repeated(tx.collapse_repeated_mul)
+    new_model = model.transform(CollapseRepeatedAdd())
+    new_model = new_model.transform(CollapseRepeatedMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
diff --git a/tests/test_factor_out_mul_sign_magnitude.py b/tests/test_factor_out_mul_sign_magnitude.py
index 2b4c0ba936366486a74b4c7fc8924e4d1aa67dda..3786492fefff046b098bce06cb2e624723b098ec 100644
--- a/tests/test_factor_out_mul_sign_magnitude.py
+++ b/tests/test_factor_out_mul_sign_magnitude.py
@@ -3,9 +3,9 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.core.onnx_exec as ox
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as tx
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import FactorOutMulSignMagnitude
 
 
 def test_factor_out_mul_sign_magnitude():
@@ -22,8 +22,8 @@ def test_factor_out_mul_sign_magnitude():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("mul_param", np.asarray([[-1, 4]], dtype=np.float32))
-    new_model = model.transform_repeated(tx.factor_out_mul_sign_magnitude)
+    new_model = model.transform(FactorOutMulSignMagnitude())
     inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
diff --git a/tests/test_fold_constants.py b/tests/test_fold_constants.py
index 894887b0c91420a5755f2e737137f2fa47862b0b..09dbd95c27cec65183ec7dd6067ce187595fcf52 100644
--- a/tests/test_fold_constants.py
+++ b/tests/test_fold_constants.py
@@ -8,9 +8,9 @@ import onnx.numpy_helper as np_helper
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.fold_constants as fc
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 
@@ -18,8 +18,8 @@ export_onnx_path = "test_output_lfc.onnx"
 def test_const_folding():
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_single(fc.fold_constants)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     raw_o = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/output_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
@@ -35,8 +35,8 @@ def test_const_folding_shapes():
     lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
     assert model.graph.node[0].op_type == "Reshape"
     assert list(model.get_tensor_shape("0")) == [1, 1, 28, 28]
     assert list(model.get_tensor_shape("27")) == [1, 784]
diff --git a/tests/test_general_transformation.py b/tests/test_general_transformation.py
index d4376cca6051bd45669c5bb1a6102dc7db76e3b4..dd44502add93d269b8d48ed951620b6a36f9fb1b 100644
--- a/tests/test_general_transformation.py
+++ b/tests/test_general_transformation.py
@@ -1,13 +1,13 @@
 from pkgutil import get_data
 
-import finn.transformation.general as tg
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.general import GiveUniqueNodeNames
 
 
 def test_give_unique_node_names():
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(tg.give_unique_node_names)
+    model = model.transform(GiveUniqueNodeNames())
     assert model.graph.node[0].name == "Reshape_0"
     assert model.graph.node[1].name == "Conv_0"
     assert model.graph.node[11].name == "Add_2"
diff --git a/tests/test_infer_datatypes.py b/tests/test_infer_datatypes.py
index fd1c99a1da5f5bdb93f0b8b089fddb20e3c078ed..e4269499c55eb89d9a9c268f79456ca6ac588028 100644
--- a/tests/test_infer_datatypes.py
+++ b/tests/test_infer_datatypes.py
@@ -4,12 +4,12 @@ import brevitas.onnx as bo
 import torch
 from models.LFC import LFC
 
-import finn.transformation.fold_constants as fc
-import finn.transformation.general as tg
-import finn.transformation.infer_datatypes as id
-import finn.transformation.infer_shapes as si
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 # TODO get from config instead, hardcoded to Docker path for now
@@ -24,11 +24,11 @@ def test_infer_datatypes():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
-    model = model.transform_single(tg.give_unique_node_names)
-    model = model.transform_single(tg.give_readable_tensor_names)
-    model = model.transform_repeated(id.infer_datatypes)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
+    model = model.transform(InferDataTypes())
     assert model.get_tensor_datatype("MatMul_0_out0") == DataType.INT32
     assert model.get_tensor_datatype("MatMul_1_out0") == DataType.INT32
     assert model.get_tensor_datatype("MatMul_2_out0") == DataType.INT32
diff --git a/tests/test_infer_shapes.py b/tests/test_infer_shapes.py
index f4c27572cf30decce49f93fb32f12868440b00e2..20841b32275968ed842fdbbebffa7168b61b7e06 100644
--- a/tests/test_infer_shapes.py
+++ b/tests/test_infer_shapes.py
@@ -3,8 +3,9 @@ from pkgutil import get_data
 import numpy as np
 from onnx import TensorProto, helper
 
-import finn.transformation.infer_shapes as si
+import finn.core.utils as util
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_infer_shapes():
@@ -26,7 +27,7 @@ def test_infer_shapes():
     # thresholds for one channel have to be sorted to guarantee the correct behavior
     mt_thresh0_values = np.empty([8, 7], dtype=np.float32)
     for i in range(len(mt_thresh0_values)):
-        mt_thresh0_values[i] = np.sort(np.random.random_sample(7,) * 10)
+        mt_thresh0_values[i] = np.sort(np.random.random_sample(7) * 10)
 
     model.set_initializer(mt_thresh0.name, mt_thresh0_values)
 
@@ -36,6 +37,9 @@ def test_infer_shapes():
     )
     Relu_node.output[0] = "mt_v0"
 
+    # remove any pre-existing shapes from ReLU and MultiThreshold outputs
+    util.remove_by_name(model.graph.value_info, Relu_node.output[0])
+    util.remove_by_name(model.graph.value_info, mt_node.output[0])
     graph.node.insert(4, mt_node)
 
     # first check routine
@@ -45,7 +49,7 @@ def test_infer_shapes():
     ), "All tensors are already specified before the shape inference execution"
 
     # perform shape inference on mixed model
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
 
     # second check routine
     # now all shapes should be specified and mt_node output shape is (1,8,28,28)
diff --git a/tests/test_is_linear.py b/tests/test_is_linear.py
index 1995604b7b0d1b816dea17500486c9ece1ac04c6..cce4612a93112608564b59d04c0370b195c5f2d4 100644
--- a/tests/test_is_linear.py
+++ b/tests/test_is_linear.py
@@ -2,8 +2,8 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.analysis.topology as ta
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_is_linear_linear():
@@ -24,7 +24,7 @@ def test_is_linear_linear():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.is_linear)
     assert ret["is_linear"] is True
 
@@ -52,6 +52,6 @@ def test_is_linear_forked_node_output():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.is_linear)
     assert ret["is_linear"] is False
diff --git a/tests/test_mixed_onnx_exec.py b/tests/test_mixed_onnx_exec.py
index 0ac3f09da2f2b9fc27d88b3c90bd178adaeaf25b..75170da3ad3a8c9b24814e171b8ccdfde4fb74cd 100644
--- a/tests/test_mixed_onnx_exec.py
+++ b/tests/test_mixed_onnx_exec.py
@@ -2,8 +2,8 @@ import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_execute_mixed_model():
@@ -30,7 +30,7 @@ def test_execute_mixed_model():
     model_def = helper.make_model(graph_def, producer_name="onnx-example")
 
     model = ModelWrapper(model_def)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
 
     inputs = np.asarray(
         [
diff --git a/tests/test_move_add_past_mul.py b/tests/test_move_add_past_mul.py
index a23827d0e91a543cf5595c816628169a251bf4ce..565bbce39b83d94352d2cd69d01d8270675bbe1f 100644
--- a/tests/test_move_add_past_mul.py
+++ b/tests/test_move_add_past_mul.py
@@ -3,9 +3,9 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.core.onnx_exec as ox
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as tx
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import MoveAddPastMul
 
 
 def test_move_add_past_mul_single():
@@ -26,10 +26,10 @@ def test_move_add_past_mul_single():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("add_param", np.asarray([1, 3], dtype=np.float32))
     model.set_initializer("mul_param", np.asarray([2, 4], dtype=np.float32))
-    new_model = model.transform_repeated(tx.move_add_past_mul)
+    new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
 
@@ -56,11 +56,11 @@ def test_move_add_past_mul_multi():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("add_param_0", np.asarray([1, 3], dtype=np.float32))
     model.set_initializer("mul_param_0", np.asarray([2, 4], dtype=np.float32))
     model.set_initializer("add_param_1", np.asarray([-1, 3], dtype=np.float32))
     model.set_initializer("mul_param_1", np.asarray([2, -4], dtype=np.float32))
-    new_model = model.transform_repeated(tx.move_add_past_mul)
+    new_model = model.transform(MoveAddPastMul())
     inp_dict = {"top_in": np.asarray([-1.0, 1.0], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
diff --git a/tests/test_move_scalar_past_matmul.py b/tests/test_move_scalar_past_matmul.py
index 7bbdd7dd8506437371147faa98310b8caa115318..c2771ce94d002ab62d33226771965140f8614ec1 100644
--- a/tests/test_move_scalar_past_matmul.py
+++ b/tests/test_move_scalar_past_matmul.py
@@ -3,9 +3,12 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.core.onnx_exec as ox
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as tx
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import (
+    MoveScalarAddPastMatMul,
+    MoveScalarMulPastMatMul
+)
 
 
 def test_move_scalar_mul_past_matmul():
@@ -26,12 +29,12 @@ def test_move_scalar_mul_past_matmul():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("mul_param", np.asarray([[3]], dtype=np.float32))
     model.set_initializer(
         "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32)
     )
-    new_model = model.transform_repeated(tx.move_scalar_mul_past_matmul)
+    new_model = model.transform(MoveScalarMulPastMatMul())
     inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
     assert new_model.graph.node[0].op_type == "MatMul"
@@ -57,12 +60,12 @@ def test_move_scalar_add_past_matmul():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     model.set_initializer("add_param", np.asarray([[3]], dtype=np.float32))
     model.set_initializer(
         "matmul_param", np.asarray([[2, 4], [-1, 1]], dtype=np.float32)
     )
-    new_model = model.transform_repeated(tx.move_scalar_add_past_matmul)
+    new_model = model.transform(MoveScalarAddPastMatMul())
     inp_dict = {"top_in": np.asarray([[-1.0, 1.0]], dtype=np.float32)}
     assert ox.compare_execution(model, new_model, inp_dict)
     assert new_model.graph.node[0].op_type == "MatMul"
diff --git a/tests/test_renaming.py b/tests/test_renaming.py
index c13f7ee066e514064a2943493d810f48aa8f97be..aec1c8a10768f293ff9aaf44d7418c777837010c 100644
--- a/tests/test_renaming.py
+++ b/tests/test_renaming.py
@@ -5,18 +5,18 @@ import onnx
 import onnx.numpy_helper as np_helper
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.general as tg
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_renaming():
     # load the onnx model
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_single(tg.give_unique_node_names)
-    model = model.transform_single(tg.give_readable_tensor_names)
+    model = model.transform(InferShapes())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
     # do some basic checks
     assert model.graph.input[0].name == "global_in"
     assert model.graph.output[0].name == "global_out"
@@ -27,8 +27,8 @@ def test_renaming():
     assert model.graph.node[6].name == "Add_1"
     assert model.graph.node[6].input[1] == "Add_1_param0"
     # ensure running renaming twice still yields the same names
-    model = model.transform_single(tg.give_unique_node_names)
-    model = model.transform_single(tg.give_readable_tensor_names)
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
     assert model.graph.node[1].op_type == "Conv"
     assert model.graph.node[1].name == "Conv_0"
     assert model.graph.node[1].input[1] == "Conv_0_param0"
diff --git a/tests/test_round_thresholds.py b/tests/test_round_thresholds.py
index 28add7313201f1e30ecaf98d841f762a3f859835..afd1d86f868ed6b1343c3f6a1265a9016c3f8cef 100644
--- a/tests/test_round_thresholds.py
+++ b/tests/test_round_thresholds.py
@@ -2,9 +2,9 @@ import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.streamline as sl
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.streamline import RoundThresholds
 
 
 def test_round_thresholds():
@@ -27,7 +27,7 @@ def test_round_thresholds():
     orig_n = oxe.execute_onnx(model, inp_dict_n)["out"]
     orig_c = oxe.execute_onnx(model, inp_dict_c)["out"]
     assert model.get_tensor_datatype("thresholds") == DataType.FLOAT32
-    new_model = model.transform_repeated(sl.round_thresholds)
+    new_model = model.transform(RoundThresholds())
     # rounded up thresholds should have same dtype as input
     assert new_model.get_tensor_datatype("thresholds") == DataType.INT8
     new_f = oxe.execute_onnx(new_model, inp_dict_f)["out"]
diff --git a/tests/test_sign_to_thres.py b/tests/test_sign_to_thres.py
index 4e18822cf9beceab009678876c1ca20f581ec22a..75327df3e5194f48f71c914bc62fc5a08588faff 100644
--- a/tests/test_sign_to_thres.py
+++ b/tests/test_sign_to_thres.py
@@ -8,10 +8,10 @@ import torch
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.fold_constants as fc
-import finn.transformation.infer_shapes as si
-import finn.transformation.streamline as sl
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.streamline import ConvertSignToThres
 
 export_onnx_path = "test_output_lfc.onnx"
 transformed_onnx_path = "test_output_lfc_transformed.onnx"
@@ -27,9 +27,9 @@ def test_sign_to_thres():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
-    new_model = model.transform_single(sl.convert_sign_to_thres)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    new_model = model.transform(ConvertSignToThres())
     assert new_model.graph.node[3].op_type == "MultiThreshold"
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
diff --git a/tests/test_streamline.py b/tests/test_streamline.py
index 1a93c6ff128b2169da4802d7367608f499e89beb..75b0301072bd603e0c25357fa022b66d1df4e467 100644
--- a/tests/test_streamline.py
+++ b/tests/test_streamline.py
@@ -9,13 +9,17 @@ import torch
 from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
-import finn.transformation.batchnorm_to_affine as ba
-import finn.transformation.fold_constants as fc
-import finn.transformation.general as tg
-import finn.transformation.infer_datatypes as di
-import finn.transformation.infer_shapes as si
 import finn.transformation.streamline as sl
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.batchnorm_to_affine import BatchNormToAffine
+from finn.transformation.fold_constants import FoldConstants
+from finn.transformation.general import (
+    ConvertSubToAdd,
+    GiveReadableTensorNames,
+    GiveUniqueNodeNames
+)
+from finn.transformation.infer_datatypes import InferDataTypes
+from finn.transformation.infer_shapes import InferShapes
 
 export_onnx_path = "test_output_lfc.onnx"
 # TODO get from config instead, hardcoded to Docker path for now
@@ -30,10 +34,10 @@ def test_streamline_lfc_w1a1():
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
-    model = model.transform_single(si.infer_shapes)
-    model = model.transform_repeated(fc.fold_constants)
-    model = model.transform_single(tg.give_unique_node_names)
-    model = model.transform_single(tg.give_readable_tensor_names)
+    model = model.transform(InferShapes())
+    model = model.transform(FoldConstants())
+    model = model.transform(GiveUniqueNodeNames())
+    model = model.transform(GiveReadableTensorNames())
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
@@ -42,26 +46,31 @@
     expected_ctx = oxe.execute_onnx(model, input_dict, True)
     expected = expected_ctx[model.graph.output[0].name]
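+    # streamlining transforms in application order: graph rewrites first,
+    # then the move/collapse/absorb passes, with threshold rounding last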
     transforms = [
-        tg.convert_sub_to_add,
-        ba.batchnorm_to_affine,
-        sl.convert_sign_to_thres,
-        sl.move_scalar_add_past_matmul,
-        sl.move_scalar_mul_past_matmul,
-        sl.move_add_past_mul,
-        sl.collapse_repeated_add,
-        sl.collapse_repeated_mul,
-        sl.absorb_add_into_multi_threshold,
-        sl.factor_out_mul_sign_magnitude,
-        sl.absorb_mul_into_multi_threshold,
-        sl.absorb_1bit_mul_into_matmul,
-        sl.round_thresholds,
+        ConvertSubToAdd(),
+        BatchNormToAffine(),
+        sl.ConvertSignToThres(),
+        sl.MoveScalarAddPastMatMul(),
+        sl.MoveScalarMulPastMatMul(),
+        sl.MoveAddPastMul(),
+        sl.CollapseRepeatedAdd(),
+        sl.CollapseRepeatedMul(),
+        sl.AbsorbAddIntoMultiThreshold(),
+        sl.FactorOutMulSignMagnitude(),
+        sl.AbsorbMulIntoMultiThreshold(),
+        sl.Absorb1BitMulIntoMatMul(),
+        sl.RoundThresholds(),
     ]
     trn_ind = 0
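+    # apply each transform until it reaches a fixed point (model.transform
+    # re-applies it until model_was_changed is False), then re-run the
+    # renaming and datatype inference passes before comparing execution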
     for trn in transforms:
-        model = model.transform_repeated(trn)
-        model = model.transform_single(tg.give_unique_node_names)
-        model = model.transform_single(tg.give_readable_tensor_names)
-        model = model.transform_repeated(di.infer_datatypes)
+        model = model.transform(trn)
+        model = model.transform(GiveUniqueNodeNames())
+        model = model.transform(GiveReadableTensorNames())
+        model = model.transform(InferDataTypes())
         produced_ctx = oxe.execute_onnx(model, input_dict, True)
         produced = produced_ctx[model.graph.output[0].name]
-        # model.save("%d-%s.onnx" % (trn_ind, trn.__name__))
+        # model.save("%d-%s.onnx" % (trn_ind, type(trn).__name__))
diff --git a/tests/test_topology_checks.py b/tests/test_topology_checks.py
index 2e9582a8dc23fe8ce5b3e409f19e8db7a332e9f7..e28ceac09ecfdd242285f6b9e355bd4c2cfd7e68 100644
--- a/tests/test_topology_checks.py
+++ b/tests/test_topology_checks.py
@@ -4,8 +4,8 @@ import onnx.helper as oh
 from onnx import TensorProto
 
 import finn.analysis.topology as ta
-import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
+from finn.transformation.infer_shapes import InferShapes
 
 
 def test_all_tensors_f32():
@@ -26,7 +26,7 @@ def test_all_tensors_f32():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.all_tensors_f32)
     assert ret["all_tensors_f32"] is True
 
@@ -47,7 +47,7 @@ def test_all_tensors_f32():
         )
     )
     model = ModelWrapper(modelproto)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.all_tensors_f32)
     assert ret["all_tensors_f32"] is False
 
@@ -55,7 +55,7 @@ def test_all_tensors_f32():
 def test_node_inputs_in_expected_order():
     raw_m = get_data("finn", "data/onnx/mnist-conv/model.onnx")
     model = ModelWrapper(raw_m)
-    model = model.transform_single(si.infer_shapes)
+    model = model.transform(InferShapes())
     ret = model.analysis(ta.node_inputs_in_expected_order)
     # this model has an (unnecessary) dynamic reshape for its weight tensor
     # and so it fails the check