Skip to content
Snippets Groups Projects
Commit 8584c46e authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[HLSCustomOp] Fix computed output shape of GlobalAccPool_Batch

parent 2bbf6272
No related branches found
No related tags found
No related merge requests found
......@@ -75,16 +75,19 @@ class GlobalAccPool_Batch(HLSCustomOp):
def get_normal_output_shape(self):
ch = self.get_nodeattr("NumChannels")
vecs = list(self.get_nodeattr("numInputVectors"))
oshape = tuple([vecs[0]] + [ch])
if len(vecs) == 1:
oshape = tuple(vecs + [ch])
elif len(vecs) == 3:
oshape = tuple([vecs[0]] + [1, 1, ch])
return oshape
def get_folded_output_shape(self):
ch = self.get_nodeattr("NumChannels")
pe = self.get_nodeattr("PE")
vecs = list(self.get_nodeattr("numInputVectors"))
unfolded_shape = list(self.get_normal_output_shape())
assert ch % pe == 0, "PE must divide NumChannels"
folds = int(ch / pe)
oshape = tuple([vecs[0]] + [folds, pe])
oshape = tuple(unfolded_shape[:-1] + [folds, pe])
return oshape
def make_shape_compatible_op(self, model):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment