Skip to content
Snippets Groups Projects
Commit 15d76acb authored by Yaman Umuroglu
Browse files

Merge branch 'feature/refactor_custom_op_infra' into dev

parents b1fe0302 e1e26246
No related branches found
No related tags found
No related merge requests found
# import onnx.helper as helper
import finn.core.multithreshold as multiThresh
from finn.core.utils import get_by_name
import finn.custom_op.registry as registry
def execute_custom_node(node, context, graph):
    """Call custom implementation to execute a single custom node.

    Input/output provided via context. The node's op_type is looked up in
    the registry of CustomOps; the registered class is instantiated and its
    execute_node method is called with the same arguments.

    Raises:
        Exception: if no CustomOp is registered under the node's op_type.
    """
    op_type = node.op_type
    try:
        # lookup op_type in registry of CustomOps; only the lookup itself is
        # guarded so that a KeyError raised while executing the node (e.g. a
        # tensor name missing from context) is not misreported as an
        # unsupported op_type
        op_class = registry.custom_op[op_type]
    except KeyError:
        # exception if op_type is not supported
        raise Exception("Custom op_type %s is currently not supported." % op_type)
    op_class().execute_node(node, context, graph)
import numpy as np
def compare(x, y):
    """Return 1.0 when x is greater than or equal to y, otherwise 0.0."""
    return 1.0 if x >= y else 0.0
def execute(v, thresholds, out_scale=None, out_bias=None):
    """Apply multi-thresholding to v and return the scaled, biased result.

    Every output element counts how many thresholds of its channel the
    corresponding input element is greater than or equal to.

    The inputs are expected to be in the shape (N,C,H,W):
        N : Batch size
        C : Number of channels
        H : Height of the input images
        W : Width of the input images

    The thresholds are expected to be in the shape (C, B):
        C : Number of channels (must be the same value as C in input tensor
            or 1 if all channels use the same threshold value)
        B : Desired activation steps => i.e. for 4-bit activation,
            B=7 (2^(n)-1 and n=4)

    The output tensor will be scaled by out_scale and biased by out_bias;
    passing None selects 1.0 and 0.0 respectively. The result has the same
    shape as v.
    """
    # assert threshold shape
    is_global_threshold = thresholds.shape[0] == 1
    assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
    num_act = thresholds.shape[1]
    # reshape inputs to enable channel-wise reading
    vr = v.reshape((v.shape[0], v.shape[1], -1))
    # initiate output tensor (same dtype as the input, as before)
    ret = np.zeros_like(vr)
    # apply successive thresholding: one whole-tensor comparison per
    # activation step instead of one Python-level comparison per element
    for a in range(num_act):
        # a-th threshold of every channel, shaped (C, 1) or (1, 1) so that
        # it broadcasts over the batch and image-element axes of vr
        step = thresholds[:, a].reshape(-1, 1)
        ret += vr >= step
    if out_scale is None:
        out_scale = 1.0
    if out_bias is None:
        out_bias = 0.0
    return out_scale * ret.reshape(v.shape) + out_bias
from abc import ABC, abstractmethod
class CustomOp(ABC):
    """Abstract base class for custom ops.

    Concrete subclasses have to implement every method declared below
    before they can be instantiated.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def make_shape_compatible_op(self, node):
        """Produce a shape-compatible op for the given node."""

    @abstractmethod
    def infer_node_datatype(self, node, model):
        """Infer datatype information for the given node within model."""

    @abstractmethod
    def execute_node(self, node, context, graph):
        """Execute the given node, using context and graph."""
import numpy as np
import onnx.helper as helper
from finn.core.datatype import DataType
from finn.core.utils import get_by_name
from finn.custom_op import CustomOp
class MultiThreshold(CustomOp):
    """CustomOp implementation of the MultiThreshold node."""

    def make_shape_compatible_op(self, node):
        """Return a Relu node with the same first input and output as node."""
        return helper.make_node("Relu", [node.input[0]], [node.output[0]])

    def infer_node_datatype(self, node, model):
        """Set the datatype of the node's output tensor on model.

        If the node carries an "out_dtype" attribute, the DataType of that
        name is used. Otherwise the datatype is derived from the number of
        thresholds (second dimension of the threshold tensor's shape).
        """
        try:
            odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
            model.set_tensor_datatype(node.output[0], DataType[odt])
        except AttributeError:
            # number of thresholds decides # output bits
            # use get_smallest_possible, assuming unsigned
            n_thres = model.get_tensor_shape(node.input[1])[1]
            odtype = DataType.get_smallest_possible(n_thres)
            model.set_tensor_datatype(node.output[0], odtype)

    def execute_node(self, node, context, graph):
        """Execute the node, reading inputs from and writing output to context.

        The first node input is the tensor to be thresholded and the second
        one holds the thresholds. The optional "out_scale" and "out_bias"
        attributes are forwarded to the computation when present.
        """
        # save inputs
        v = context[node.input[0]]
        thresholds = context[node.input[1]]
        # retrieve attributes if output scaling is used
        try:
            out_scale = get_by_name(node.attribute, "out_scale").f
        except AttributeError:
            out_scale = None
        try:
            out_bias = get_by_name(node.attribute, "out_bias").f
        except AttributeError:
            out_bias = None
        # calculate output
        output = self._execute(v, thresholds, out_scale, out_bias)
        # setting context according to output
        context[node.output[0]] = output

    def _compare(self, x, y):
        """Return 1.0 if x >= y, otherwise 0.0."""
        if x >= y:
            return 1.0
        else:
            return 0.0

    def _execute(self, v, thresholds, out_scale=None, out_bias=None):
        """Apply multi-thresholding to v and return the scaled, biased result.

        Every output element counts how many thresholds of its channel the
        corresponding input element is greater than or equal to.

        The inputs are expected to be in the shape (N,C,H,W):
            N : Batch size
            C : Number of channels
            H : Height of the input images
            W : Width of the input images

        The thresholds are expected to be in the shape (C, B):
            C : Number of channels (must be the same value as C in input
                tensor or 1 if all channels use the same threshold value)
            B : Desired activation steps => i.e. for 4-bit activation,
                B=7 (2^(n)-1 and n=4)

        The output tensor will be scaled by out_scale and biased by
        out_bias; passing None selects 1.0 and 0.0 respectively. The result
        has the same shape as v.
        """
        # assert threshold shape
        is_global_threshold = thresholds.shape[0] == 1
        assert (v.shape[1] == thresholds.shape[0]) or is_global_threshold
        num_act = thresholds.shape[1]
        # reshape inputs to enable channel-wise reading
        vr = v.reshape((v.shape[0], v.shape[1], -1))
        # initiate output tensor (same dtype as the input, as before)
        ret = np.zeros_like(vr)
        # apply successive thresholding: one whole-tensor comparison per
        # activation step instead of one Python-level comparison per element
        for a in range(num_act):
            # a-th threshold of every channel, shaped (C, 1) or (1, 1) so
            # that it broadcasts over the batch and image-element axes of vr
            step = thresholds[:, a].reshape(-1, 1)
            ret += vr >= step
        if out_scale is None:
            out_scale = 1.0
        if out_bias is None:
            out_bias = 0.0
        return out_scale * ret.reshape(v.shape) + out_bias
# make sure new CustomOp subclasses are imported here so that they get
# registered and plug in correctly into the infrastructure
from finn.custom_op.multithreshold import MultiThreshold
# mapping from every known CustomOp name to the class that implements it
custom_op = {"MultiThreshold": MultiThreshold}
import finn.custom_op.registry as registry
from finn.core.datatype import DataType
from finn.core.utils import get_by_name
from finn.transformation import Transformation
......@@ -8,36 +8,37 @@ def _infer_node_datatype(model, node):
changes were made."""
idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input))
odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
if node.op_type == "MultiThreshold":
op_type = node.op_type
if node.domain == "finn":
# handle DataType inference for CustomOp
try:
odt = get_by_name(node.attribute, "out_dtype").s.decode("utf-8")
model.set_tensor_datatype(node.output[0], DataType[odt])
except AttributeError:
# number of thresholds decides # output bits
# use get_smallest_possible, assuming unsigned
n_thres = model.get_tensor_shape(node.input[1])[1]
odtype = DataType.get_smallest_possible(n_thres)
model.set_tensor_datatype(node.output[0], odtype)
elif node.op_type == "Sign":
# always produces bipolar outputs
model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
elif node.op_type == "MatMul":
if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
# node has at least one float input, output is also float
model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
else:
# TODO compute minimum / maximum result to minimize bitwidth
# use (u)int32 accumulators for now
has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
if has_signed_inp:
odtype = DataType.INT32
else:
odtype = DataType.UINT32
model.set_tensor_datatype(node.output[0], odtype)
# lookup op_type in registry of CustomOps
inst = registry.custom_op[op_type]()
inst.infer_node_datatype(node, model)
except KeyError:
# exception if op_type is not supported
raise Exception("Custom op_type %s is currently not supported." % op_type)
else:
# unknown, assume node produces float32 outputs
for o in node.output:
model.set_tensor_datatype(o, DataType.FLOAT32)
if node.op_type == "Sign":
# always produces bipolar outputs
model.set_tensor_datatype(node.output[0], DataType.BIPOLAR)
elif node.op_type == "MatMul":
if len(list(filter(lambda x: x == DataType.FLOAT32, idtypes))) != 0:
# node has at least one float input, output is also float
model.set_tensor_datatype(node.output[0], DataType.FLOAT32)
else:
# TODO compute minimum / maximum result to minimize bitwidth
# use (u)int32 accumulators for now
has_signed_inp = len(list(filter(lambda x: x.signed(), idtypes))) != 0
if has_signed_inp:
odtype = DataType.INT32
else:
odtype = DataType.UINT32
model.set_tensor_datatype(node.output[0], odtype)
else:
# unknown, assume node produces float32 outputs
for o in node.output:
model.set_tensor_datatype(o, DataType.FLOAT32)
# compare old and new output dtypes to see if anything changed
new_odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output))
graph_modified = new_odtypes != odtypes
......
import onnx.helper as helper
import onnx.shape_inference as si
import finn.custom_op.registry as registry
from finn.core.modelwrapper import ModelWrapper
from finn.transformation import Transformation
def _make_shape_compatible_op(node):
    """Return a shape-compatible non-FINN op for a given FINN op. Used for
    shape inference with custom ops.

    The node's op_type is looked up in the registry of CustomOps and the
    registered class is asked to build the replacement op.

    Raises:
        Exception: if no CustomOp is registered under the node's op_type.
    """
    assert node.domain == "finn"
    op_type = node.op_type
    try:
        # lookup op_type in registry of CustomOps; only the lookup itself is
        # guarded so that a KeyError raised while building the replacement
        # op is not misreported as an unsupported op_type
        op_class = registry.custom_op[op_type]
    except KeyError:
        # exception if op_type is not supported
        raise Exception("Custom op_type %s is currently not supported." % op_type)
    return op_class().make_shape_compatible_op(node)
def _hide_finn_ops(model):
......
import numpy as np
import finn.core.multithreshold as multi_thresh
from finn.custom_op.multithreshold import MultiThreshold
def test_execute_multi_thresholding():
......@@ -194,10 +194,11 @@ def test_execute_multi_thresholding():
),
)
results = multi_thresh.execute(inputs, thresholds)
multi_thresh = MultiThreshold()
results = multi_thresh._execute(inputs, thresholds)
assert (results == outputs).all()
results_scaled = multi_thresh.execute(inputs, thresholds, 2.0, -1.0)
results_scaled = multi_thresh._execute(inputs, thresholds, 2.0, -1.0)
outputs_scaled = 2.0 * outputs - 1.0
assert (results_scaled == outputs_scaled).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment