diff --git a/src/finn/util/pytorch.py b/src/finn/util/pytorch.py
index 89c773ef1dc78ac9d35e00432f4d8e78fa0a8622..8332757cab839c8ce2fe7afa2449da5782d1aea3 100644
--- a/src/finn/util/pytorch.py
+++ b/src/finn/util/pytorch.py
@@ -25,6 +25,63 @@
 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import torch
+
+from torch.nn import Module, Sequential
+from brevitas.quant_tensor import QuantTensor
+
+
+class Normalize(Module):
+    """Subtract a per-channel mean and divide by a standard deviation.
+
+    Assumes NCHW input — TODO confirm against callers. `mean` must have
+    `channels` elements; `std` may be a scalar or a `channels`-element sequence.
+    """
+
+    def __init__(self, mean, std, channels):
+        super(Normalize, self).__init__()
+
+        self.mean = mean
+        self.std = std
+        self.channels = channels
+
+    def forward(self, x):
+        # Reshape the mean so it broadcasts over batch and spatial dims.
+        mean = torch.tensor(self.mean, device=x.device).reshape(
+            1, self.channels, 1, 1
+        )
+        x = x - mean
+        # Give a per-channel std the same broadcast shape; a scalar std is
+        # used directly (previously only the scalar case divided correctly).
+        std = self.std
+        if isinstance(std, (list, tuple)):
+            std = torch.tensor(std, device=x.device).reshape(
+                1, self.channels, 1, 1
+            )
+        return x / std
+
+
+class ToTensor(Module):
+    """Scale uint8-range input in [0, 255] down to [0, 1]."""
+
+    def __init__(self):
+        super(ToTensor, self).__init__()
+
+    def forward(self, x):
+        return x / 255
+
+
+class NormalizePreProc(Module):
+    """Preprocessing pipeline: scale to [0, 1], then mean/std normalize."""
+
+    def __init__(self, mean, std, channels):
+        super(NormalizePreProc, self).__init__()
+        self.features = Sequential()
+        self.features.add_module("scaling", ToTensor())
+        self.features.add_module("normalize", Normalize(mean, std, channels))
+
+    def forward(self, x):
+        return self.features(x)
 
 
 class BrevitasDebugHook:
@@ -32,7 +72,10 @@ class BrevitasDebugHook:
         self.outputs = {}
 
     def __call__(self, module, module_in, module_out):
-        self.outputs[module.export_debug_name] = module_out.detach().numpy()
+        # Brevitas quant layers emit a QuantTensor NamedTuple; its first
+        # field holds the value tensor, so unwrap it before numpy conversion.
+        tensor = module_out.value if isinstance(module_out, QuantTensor) else module_out
+        self.outputs[module.export_debug_name] = tensor.detach().numpy()
 
     def clear(self):
         self.outputs = {}