Skip to content
Snippets Groups Projects
Unverified Commit 01fdd9c5 authored by Yaman Umuroglu's avatar Yaman Umuroglu Committed by GitHub
Browse files

Merge pull request #143 from quetric/feature/fix_globalaccpool_batch_shape

Feature/fix globalAccpool batch shape
parents bdb56b28 9b3b0162
No related branches found
No related tags found
No related merge requests found
...@@ -75,16 +75,19 @@ class GlobalAccPool_Batch(HLSCustomOp): ...@@ -75,16 +75,19 @@ class GlobalAccPool_Batch(HLSCustomOp):
def get_normal_output_shape(self): def get_normal_output_shape(self):
ch = self.get_nodeattr("NumChannels") ch = self.get_nodeattr("NumChannels")
vecs = list(self.get_nodeattr("numInputVectors")) vecs = list(self.get_nodeattr("numInputVectors"))
oshape = tuple([vecs[0]] + [ch]) if len(vecs) == 1:
oshape = tuple(vecs + [ch])
elif len(vecs) == 3:
oshape = tuple([vecs[0]] + [1, 1, ch])
return oshape return oshape
def get_folded_output_shape(self): def get_folded_output_shape(self):
ch = self.get_nodeattr("NumChannels") ch = self.get_nodeattr("NumChannels")
pe = self.get_nodeattr("PE") pe = self.get_nodeattr("PE")
vecs = list(self.get_nodeattr("numInputVectors")) unfolded_shape = list(self.get_normal_output_shape())
assert ch % pe == 0, "PE must divide NumChannels" assert ch % pe == 0, "PE must divide NumChannels"
folds = int(ch / pe) folds = int(ch / pe)
oshape = tuple([vecs[0]] + [folds, pe]) oshape = tuple(unfolded_shape[:-1] + [folds, pe])
return oshape return oshape
def make_shape_compatible_op(self, model): def make_shape_compatible_op(self, model):
......
...@@ -785,3 +785,88 @@ class InferChannelwiseLinearLayer(Transformation): ...@@ -785,3 +785,88 @@ class InferChannelwiseLinearLayer(Transformation):
model = model.transform(InferShapes()) model = model.transform(InferShapes())
model = model.transform(InferDataTypes()) model = model.transform(InferDataTypes())
return (model, graph_modified) return (model, graph_modified)
class InferGlobalAccPoolLayer(Transformation):
"""Convert any GlobalAveragePool into a GlobalAccPool HLS layer and a scalar Mul."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for node in graph.node:
node_ind += 1
if node.op_type == "GlobalAveragePool":
in0 = node.input[0]
result = node.output[0]
in0_shape = model.get_tensor_shape(in0)
idt = model.get_tensor_datatype(in0)
# skip conversion for layers with float input
if not idt.is_integer():
continue
# check layout and convert if necessary
in0_layout = model.get_tensor_layout(in0)
result_layout = model.get_tensor_layout(result)
if in0_layout == DataLayout.NCHW:
in0 = nchw_to_nhwc(in0, model, node_ind)
node_ind += 1
in0_shape = model.get_tensor_shape(in0)
# keep track of where we need to insert the HLS Op
# it has to be ahead of the output transform
insert_point = node_ind
if result_layout == DataLayout.NCHW:
result = nchw_to_nhwc(result, model, node_ind, reverse=True)
node_ind += 1
num_ch = int(in0_shape[-1])
vecs = in0_shape[:-1]
# create node with no parallelization first
pe = 1
assert (
num_ch % pe == 0
), "Requirement Labels divisable by PE is violated."
# create an additional tensor of the same shape and layout as result
out_shape = model.get_tensor_shape(result)
pool_out = helper.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, out_shape
)
model.graph.value_info.append(pool_out)
pool_out = pool_out.name
model.set_tensor_layout(pool_out, model.get_tensor_layout(result))
new_pool = helper.make_node(
"GlobalAccPool_Batch",
[in0],
[pool_out],
domain="finn",
backend="fpgadataflow",
NumChannels=num_ch,
PE=pe,
inputDataType=idt.name,
numInputVectors=vecs,
)
mul_value = helper.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, [1]
)
model.graph.value_info.append(mul_value)
model.set_initializer(mul_value.name, np.array(1 / (vecs[1] * vecs[2])))
new_mul = helper.make_node("Mul", [pool_out, mul_value.name], [result],)
graph.node.insert(insert_point, new_pool)
graph.node.insert(insert_point + 1, new_mul)
node_ind += 1
# remove old node
graph.node.remove(node)
graph_modified = True
if graph_modified:
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment