Skip to content
Snippets Groups Projects
Commit 6fb6dfc9 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] use new attribute system in MultiThreshold

parent 277e2558
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,6 @@ import numpy as np
import onnx.helper as helper
from finn.core.datatype import DataType
from finn.core.utils import get_by_name
from finn.custom_op import CustomOp
......@@ -71,15 +70,8 @@ class MultiThreshold(CustomOp):
def infer_node_datatype(self, model):
node = self.onnx_node
try:
odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
model.set_tensor_datatype(node.output[0], DataType[odt])
except AttributeError:
# number of thresholds decides # output bits
# use get_smallest_possible, assuming unsigned
n_thres = model.get_tensor_shape(node.input[1])[1]
odtype = DataType.get_smallest_possible(n_thres)
model.set_tensor_datatype(node.output[0], odtype)
odt = self.get_nodeattr("out_dtype")
model.set_tensor_datatype(node.output[0], DataType[odt])
def execute_node(self, context, graph):
node = self.onnx_node
......@@ -87,14 +79,8 @@ class MultiThreshold(CustomOp):
v = context[node.input[0]]
thresholds = context[node.input[1]]
# retrieve attributes if output scaling is used
try:
out_scale = get_by_name(node.attribute, "out_scale").f
except AttributeError:
out_scale = None
try:
out_bias = get_by_name(node.attribute, "out_bias").f
except AttributeError:
out_bias = None
out_scale = self.get_nodeattr("out_scale")
out_bias = self.get_nodeattr("out_bias")
# calculate output
output = multithreshold(v, thresholds, out_scale, out_bias)
# setting context according to output
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment