Skip to content
Snippets Groups Projects
Commit b4e09c46 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by Yaman Umuroglu
Browse files

[Util] support logging QuantTensors in forward hook

parent 487c6474
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,46 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import torch
from torch.nn import Module, Sequential
from brevitas.quant_tensor import QuantTensor
class Normalize(Module):
def __init__(self, mean, std, channels):
super(Normalize, self).__init__()
self.mean = mean
self.std = std
self.channels = channels
def forward(self, x):
x = x - torch.tensor(self.mean, device=x.device).reshape(1, self.channels, 1, 1)
x = x / self.std
return x
class ToTensor(Module):
def __init__(self):
super(ToTensor, self).__init__()
def forward(self, x):
x = x / 255
return x
class NormalizePreProc(Module):
def __init__(self, mean, std, channels):
super(NormalizePreProc, self).__init__()
self.features = Sequential()
scaling = ToTensor()
self.features.add_module("scaling", scaling)
normalize = Normalize(mean, std, channels)
self.features.add_module("normalize", normalize)
def forward(self, x):
return self.features(x)
class BrevitasDebugHook:
......@@ -32,7 +72,10 @@ class BrevitasDebugHook:
self.outputs = {}
def __call__(self, module, module_in, module_out):
self.outputs[module.export_debug_name] = module_out.detach().numpy()
tensor = module_out
if isinstance(module_out, QuantTensor):
tensor = module_out[0]
self.outputs[module.export_debug_name] = tensor.detach().numpy()
def clear(self):
self.outputs = {}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment