Skip to content
Snippets Groups Projects
Commit 4e377369 authored by Tobi-Alonso's avatar Tobi-Alonso
Browse files

[fpgadataflow] Add ext weight count check

parent fe1d103b
No related branches found
No related tags found
No related merge requests found
......@@ -156,6 +156,12 @@ class FINNExampleOverlay(Overlay):
self.external_weights +=[(iwdma,weight_buf)]
if "number_of_external_weights" in self._io_shape_dict:
hw_ext_weights = self._io_shape_dict["number_of_external_weights"]
assert len(self.external_weights) == hw_ext_weights, (
"Number of hardware external weights and number of external " +
"weight tensors available do not match. \n"+
"Is runtime_weight_dir pointing to the correct folder?")
def load_runtime_weights(self, flush_accel=True, verify=True):
......
......@@ -123,6 +123,7 @@ class MakePYNQDriver(Transformation):
os.makedirs(weights_dir)
idma_idx = 0
ext_weight_dma_cnt = 0
for node in model.graph.node:
assert node.op_type == "StreamingDataflowPartition", (
......@@ -134,6 +135,7 @@ class MakePYNQDriver(Transformation):
if producer is None : # input dma?
idma_name = "idma" + str(idma_idx)
if init_tensor is not None: # input weights dma?
ext_weight_dma_cnt += 1
w_dtype = model.get_tensor_datatype(node.input[0])
init_external_tensor = to_external_tensor(init_tensor,w_dtype)
np.save(weights_dir+"/"+ idma_name+".npy",init_external_tensor)
......@@ -169,7 +171,8 @@ class MakePYNQDriver(Transformation):
driver = driver.replace("$OUTPUT_SHAPE_NORMAL$", mss(o_tensor_shape_normal))
driver = driver.replace("$OUTPUT_SHAPE_FOLDED$", mss(o_tensor_shape_folded))
driver = driver.replace("$OUTPUT_SHAPE_PACKED$", mss(o_tensor_shape_packed))
driver = driver.replace("$INPUT_DMA_NAME$", "'%s'" %net_input_name)
driver = driver.replace("$INPUT_DMA_NAME$", "'%s'" % net_input_name)
driver = driver.replace("$EXT_WEIGHT_NUM$", str(ext_weight_dma_cnt) )
with open(driver_py, "w") as f:
f.write(driver)
......
......@@ -78,7 +78,8 @@ io_shape_dict = {
"oshape_folded" : $OUTPUT_SHAPE_FOLDED$,
"ishape_packed" : $INPUT_SHAPE_PACKED$,
"oshape_packed" : $OUTPUT_SHAPE_PACKED$,
"input_dma_name" : $INPUT_DMA_NAME$
"input_dma_name" : $INPUT_DMA_NAME$,
"number_of_external_weights": $EXT_WEIGHT_NUM$
}
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment