Skip to content
Snippets Groups Projects
Commit 599d8497 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add InferStreamingMaxPool

parent 909e645b
No related branches found
No related tags found
No related merge requests found
......@@ -83,6 +83,49 @@ class InferConvInpGen(Transformation):
return (model, graph_modified)
class InferStreamingMaxPool(Transformation):
"""Convert MaxPoolNHWC layers to StreamingMaxPool layers."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "MaxPoolNHWC":
mp_input = n.input[0]
mp_output = n.output[0]
mp_in_shape = model.get_tensor_shape(mp_input)
# mp_out_shape = model.get_tensor_shape(mp_output)
dt = model.get_tensor_datatype(mp_input)
mp_inst = getCustomOp(n)
# stride = mp_inst.get_nodeattr("strides")[0]
k = mp_inst.get_nodeattr("kernel_shape")[0]
# pad = mp_inst.get_nodeattr("pads")[0]
ifm_ch = mp_in_shape[-1]
ifm_dim = mp_in_shape[1]
# ofm_dim = mp_out_shape[1]
if ifm_dim % k == 0:
# create equivalent StreamingMaxPool_Batch node
# TODO support non-k strides
new_node = helper.make_node(
"StreamingMaxPool_Batch",
[mp_input],
[mp_output],
domain="finn",
backend="fpgadataflow",
PoolDim=k,
NumChannels=ifm_ch,
ImgDim=ifm_dim,
dataType=dt.name,
)
graph.node.insert(node_ind, new_node)
# remove old nodes
graph.node.remove(n)
graph_modified = True
return (model, graph_modified)
class InferBinaryStreamingFCLayer(Transformation):
"""Convert XnorPopcountMatMul layers to
StreamingFCLayer_Batch layers. Any immediately following MultiThreshold
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment