From b4e09c467a88fce1832c5dfdf2b2c6193af0db52 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Fri, 17 Jul 2020 12:00:23 +0100
Subject: [PATCH] [Util] support logging QuantTensors in forward hook

---
 src/finn/util/pytorch.py | 45 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 44 insertions(+), 1 deletion(-)

diff --git a/src/finn/util/pytorch.py b/src/finn/util/pytorch.py
index 89c773ef1..8332757ca 100644
--- a/src/finn/util/pytorch.py
+++ b/src/finn/util/pytorch.py
@@ -25,6 +25,46 @@
 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import torch
+
+from torch.nn import Module, Sequential
+from brevitas.quant_tensor import QuantTensor
+
+
+class Normalize(Module):
+    def __init__(self, mean, std, channels):
+        super(Normalize, self).__init__()
+
+        self.mean = mean
+        self.std = std
+        self.channels = channels
+
+    def forward(self, x):
+        x = x - torch.tensor(self.mean, device=x.device).reshape(1, self.channels, 1, 1)
+        x = x / self.std
+        return x
+
+
+class ToTensor(Module):
+    def __init__(self):
+        super(ToTensor, self).__init__()
+
+    def forward(self, x):
+        x = x / 255
+        return x
+
+
+class NormalizePreProc(Module):
+    def __init__(self, mean, std, channels):
+        super(NormalizePreProc, self).__init__()
+        self.features = Sequential()
+        scaling = ToTensor()
+        self.features.add_module("scaling", scaling)
+        normalize = Normalize(mean, std, channels)
+        self.features.add_module("normalize", normalize)
+
+    def forward(self, x):
+        return self.features(x)
 
 
 class BrevitasDebugHook:
@@ -32,7 +72,10 @@ class BrevitasDebugHook:
         self.outputs = {}
 
     def __call__(self, module, module_in, module_out):
-        self.outputs[module.export_debug_name] = module_out.detach().numpy()
+        tensor = module_out
+        if isinstance(module_out, QuantTensor):
+            tensor = module_out[0]
+        self.outputs[module.export_debug_name] = tensor.detach().numpy()
 
     def clear(self):
         self.outputs = {}
-- 
GitLab
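
A minimal usage sketch of the change above: it registers BrevitasDebugHook as a regular PyTorch forward hook on a Brevitas layer that returns a QuantTensor, with NormalizePreProc handling the input preprocessing added in the same patch. The QuantIdentity layer, the mean/std values, the key "quant_identity_0", and the manual assignment of export_debug_name are illustrative assumptions; only BrevitasDebugHook, NormalizePreProc, export_debug_name, outputs and clear() come from the patched file, and register_forward_hook is standard PyTorch.

    import torch
    from brevitas.nn import QuantIdentity

    from finn.util.pytorch import BrevitasDebugHook, NormalizePreProc

    # Toy preprocessing: scale uint8-range input to [0, 1], then normalize per channel.
    preproc = NormalizePreProc(mean=[0.485, 0.456, 0.406], std=0.226, channels=3)

    # A single quantized activation stands in for a full Brevitas model; with
    # return_quant_tensor=True its forward hook receives a QuantTensor.
    layer = QuantIdentity(return_quant_tensor=True)
    # export_debug_name is assigned by hand here purely for illustration.
    layer.export_debug_name = "quant_identity_0"

    hook = BrevitasDebugHook()
    handle = layer.register_forward_hook(hook)

    x = torch.randint(0, 256, (1, 3, 32, 32)).float()
    _ = layer(preproc(x))

    # The hook stores plain numpy arrays keyed by export_debug_name, whether the
    # module returned a torch.Tensor or a QuantTensor.
    print(hook.outputs["quant_identity_0"].shape)

    hook.clear()
    handle.remove()

Because the hook converts everything to numpy via tensor.detach().numpy(), only the value tensor of a QuantTensor (element 0) is recorded; its scale and bit-width metadata are not stored.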