Commit 8e5b8bac authored by Yaman Umuroglu

[Transform] fix InferBinaryStreamingFCLayer shapes and attrs

parent e7e75d25
import onnx.helper as oh
from onnx import helper
from finn.core.datatype import DataType
from finn.transformation import Transformation
@@ -35,40 +36,43 @@ class InferBinaryStreamingFCLayer(Transformation):
# create node with no parallelization first
pe = 1
simd = 1
wmem = int(mw * mh)
# extract threshold shape
tmem = int(mh / pe)
n_thres = T.shape[1]
assert mh % pe == 0
assert mw % simd == 0
wmem = mw * mh // (pe * simd)
assert mw * mh == wmem * pe * simd
nf = mh // pe
tmem = nf
assert T.shape[0] == 1 or T.shape[0] == mh
assert n_thres == 1
# W is expected to be (PE, WMEM, SIMD)
# transpose first to meet finn-hlslib assumptions
W_new = W.transpose().reshape(pe, wmem, simd)
model.set_initializer(mm_weight, W_new)
# T is expected to be (NF, PE, n_thres)
# TODO need to double-check the threshold shape here
T_new = T.reshape(pe, tmem, n_thres)
model.set_initializer(mt_thres, T_new)
# reshape input and output tensors to expected shape
# input is expected to be (1, mw/simd, simd)
# output is expected to be (1, mh/pe, pe)
in_shape = [1, int(mw / simd), simd]
out_shape = [1, int(mh / pe), pe]
idt = DataType.BINARY
wdt = DataType.BINARY
odt = model.get_tensor_datatype(mt_output)
if odt.bitwidth() == 1:
    # covers both bipolar and binary
    actval = 0
else:
    actval = odt.min()
in_shape = [1, mw]
out_shape = [1, mh]
model.set_tensor_shape(mm_input, in_shape)
model.set_tensor_shape(mt_output, out_shape)
# create and insert new StreamingFCLayer node
new_node = oh.make_node(
new_node = helper.make_node(
"StreamingFCLayer_Batch",
[mm_input, mm_weight, mt_thres],
[mt_output],
domain="finn",
backend="fpgadataflow",
MH=mh,
MW=mw,
PE=1,
SIMD=1,
resDataType="Recast<XnorMul>",
resType="ap_resource_lut()",
MW=mw,
MH=mh,
SIMD=simd,
PE=pe,
WMEM=wmem,
TMEM=tmem,
inputDataType=idt.name,
weightDataType=wdt.name,
outputDataType=odt.name,
ActVal=actval,
)
graph.node.insert(node_ind, new_node)
# remove old nodes
......
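The folding arithmetic and activation-value rule in the hunk above can be checked in isolation. Below is a minimal sketch (not part of the commit), assuming NumPy weights with the first LFC layer dimensions used in the test below (mw=784, mh=1024), the no-parallelization defaults pe=simd=1, and an importable finn package; the DataType import mirrors the one at the top of this file.

import numpy as np
from finn.core.datatype import DataType

# first LFC layer dimensions from the test below; pe/simd as set above
mw, mh = 784, 1024
pe, simd = 1, 1
assert mh % pe == 0 and mw % simd == 0
wmem = mw * mh // (pe * simd)  # 802816 weight memory entries
nf = mh // pe                  # 1024 neuron folds
tmem = nf                      # threshold memory depth follows the neuron fold

# finn-hlslib expects weights as (PE, WMEM, SIMD); transpose the (MW, MH)
# initializer first to match its row/column ordering
W = np.random.randint(0, 2, size=(mw, mh)).astype(np.float32)
W_new = W.transpose().reshape(pe, wmem, simd)
print(W_new.shape)  # (1, 802816, 1)

# ActVal rule from above: zero for 1-bit (binary/bipolar) outputs,
# otherwise the smallest value representable in the output datatype
odt = DataType.BINARY
actval = 0 if odt.bitwidth() == 1 else odt.min()
print(actval)  # 0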
@@ -37,17 +37,17 @@ def test_convert_to_hls_layers_lfc_w1a1():
model = model.transform(to_hls.InferBinaryStreamingFCLayer())
fc0 = model.graph.node[2]
assert fc0.op_type == "StreamingFCLayer_Batch"
assert model.get_tensor_shape(fc0.input[0]) == [1, 784, 1]
assert model.get_tensor_shape(fc0.input[1]) == [1, 784 * 1024, 1]
assert model.get_tensor_shape(fc0.input[2]) == [1, 1024, 1]
assert model.get_tensor_shape(fc0.input[0]) == [1, 784]
assert model.get_tensor_shape(fc0.input[1]) == [784, 1024]
assert model.get_tensor_shape(fc0.input[2]) == [1024, 1]
fc1 = model.graph.node[3]
assert fc1.op_type == "StreamingFCLayer_Batch"
assert model.get_tensor_shape(fc1.input[0]) == [1, 1024, 1]
assert model.get_tensor_shape(fc1.input[1]) == [1, 1024 * 1024, 1]
assert model.get_tensor_shape(fc1.input[2]) == [1, 1024, 1]
assert model.get_tensor_shape(fc1.input[0]) == [1, 1024]
assert model.get_tensor_shape(fc1.input[1]) == [1024, 1024]
assert model.get_tensor_shape(fc1.input[2]) == [1024, 1]
fc2 = model.graph.node[4]
assert fc2.op_type == "StreamingFCLayer_Batch"
assert model.get_tensor_shape(fc2.input[0]) == [1, 1024, 1]
assert model.get_tensor_shape(fc2.input[1]) == [1, 1024 * 1024, 1]
assert model.get_tensor_shape(fc2.input[2]) == [1, 1024, 1]
assert model.get_tensor_shape(fc2.input[0]) == [1, 1024]
assert model.get_tensor_shape(fc2.input[1]) == [1024, 1024]
assert model.get_tensor_shape(fc2.input[2]) == [1024, 1]
os.remove(export_onnx_path)
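As a possible follow-up check (not part of this commit), the attributes passed as keyword arguments to helper.make_node can be read back from the inserted NodeProto with the standard onnx helper; a short sketch for the first layer, assuming the values the transform sets with pe=simd=1:

from onnx import helper

fc0 = model.graph.node[2]
attrs = {a.name: helper.get_attribute_value(a) for a in fc0.attribute}
assert attrs["MW"] == 784 and attrs["MH"] == 1024
assert attrs["SIMD"] == 1 and attrs["PE"] == 1
assert attrs["WMEM"] == 784 * 1024 and attrs["TMEM"] == 1024
# string-valued attributes come back as bytes from onnx
assert attrs["inputDataType"] == b"BINARY"
assert attrs["backend"] == b"fpgadataflow"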