diff --git a/src/finn/util/pytorch.py b/src/finn/util/pytorch.py index 89c773ef1dc78ac9d35e00432f4d8e78fa0a8622..8332757cab839c8ce2fe7afa2449da5782d1aea3 100644 --- a/src/finn/util/pytorch.py +++ b/src/finn/util/pytorch.py @@ -25,6 +25,46 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import torch + +from torch.nn import Module, Sequential +from brevitas.quant_tensor import QuantTensor + + +class Normalize(Module): + def __init__(self, mean, std, channels): + super(Normalize, self).__init__() + + self.mean = mean + self.std = std + self.channels = channels + + def forward(self, x): + x = x - torch.tensor(self.mean, device=x.device).reshape(1, self.channels, 1, 1) + x = x / self.std + return x + + +class ToTensor(Module): + def __init__(self): + super(ToTensor, self).__init__() + + def forward(self, x): + x = x / 255 + return x + + +class NormalizePreProc(Module): + def __init__(self, mean, std, channels): + super(NormalizePreProc, self).__init__() + self.features = Sequential() + scaling = ToTensor() + self.features.add_module("scaling", scaling) + normalize = Normalize(mean, std, channels) + self.features.add_module("normalize", normalize) + + def forward(self, x): + return self.features(x) class BrevitasDebugHook: @@ -32,7 +72,10 @@ class BrevitasDebugHook: self.outputs = {} def __call__(self, module, module_in, module_out): - self.outputs[module.export_debug_name] = module_out.detach().numpy() + tensor = module_out + if isinstance(module_out, QuantTensor): + tensor = module_out[0] + self.outputs[module.export_debug_name] = tensor.detach().numpy() def clear(self): self.outputs = {}