Skip to content
Snippets Groups Projects
Commit 6c49cda9 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[InsertIODMA] refactoring: use existing util functions

parent 75d7b82f
No related branches found
No related tags found
No related merge requests found
......@@ -33,7 +33,6 @@ from finn.util.basic import get_by_name
from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation
from finn.transformation.general import SortGraph
import finn.core.data_layout as DataLayout
import math
import numpy as np
......@@ -59,6 +58,10 @@ class InsertIODMA(Transformation):
addr = 1: [(pe-1,simd*2-1),.......(0,simd+1),(0,simd)]
.
"""
# TODO: refactor this into streamingfclayer_batch.py, could go into
# make_weight_file except it doesn't write a file but returns a npy
# array instead
w_shape = weights.shape
assert len(w_shape) == 2, "weights withincorrect number of dims"
inp_w, out_w = w_shape
......@@ -70,13 +73,15 @@ class InsertIODMA(Transformation):
addr = 0
for fr in range(out_w // pe):
for fc in range(inp_w // simd):
tile = weights[
(fc * simd) : ((fc + 1) * simd), (fr * pe) : ((fr + 1) * pe)
]
w0_lower = fc * simd
w0_upper = (fc + 1) * simd
w1_lower = fr * pe
w1_upper = (fr + 1) * pe
tile = weights[w0_lower:w0_upper, w1_lower:w1_upper]
for p in range(pe):
reshaped_w[addr, (p * simd) : ((p + 1) * simd)] = tile[
:, p
].transpose()
rw0_lower = p * simd
rw0_upper = (p + 1) * simd
reshaped_w[addr, rw0_lower:rw0_upper] = tile[:, p].transpose()
addr += 1
reshaped_w = np.flip(reshaped_w, axis=-1)
return reshaped_w
......@@ -92,8 +97,7 @@ class InsertIODMA(Transformation):
fc_extw_nodes = list(
filter(
lambda x: x.op_type == "StreamingFCLayer_Batch"
and get_by_name(x.attribute, "mem_mode") is not None
and get_by_name(x.attribute, "mem_mode").s.decode("UTF-8") == "external"
and getCustomOp(x).get_nodeattr("mem_mode") == "external"
and model.find_producer(x.input[1]) is None,
all_nodes,
)
......@@ -191,6 +195,7 @@ class InsertIODMA(Transformation):
)
model.graph.node.insert(0, dma_node)
for fc_node in fc_extw_nodes:
fc_inst = getCustomOp(fc_node)
fc_w_name = fc_node.input[1]
w_shape = model.get_tensor_shape(fc_w_name)
w_dtype = model.get_tensor_datatype(fc_w_name)
......@@ -203,7 +208,7 @@ class InsertIODMA(Transformation):
# calculate width of stream output from DMA
pe = get_by_name(fc_node.attribute, "PE").i
simd = get_by_name(fc_node.attribute, "SIMD").i
streamWidth = simd * pe * w_dtype.bitwidth()
streamWidth = fc_inst.get_weightstream_width_padded()
# make new buffer
W = model.get_initializer(fc_w_name)
iodma_mem = self.get_mem_init(W, pe, simd)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment