Skip to content
Snippets Groups Projects
Commit bce0e6ca authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[FIFO] try sizing bypass FIFOs with a new approach

parent 5356ace5
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,7 @@ import numpy as np
import qonnx.custom_op.registry as registry
import warnings
from pyverilator.util.axi_utils import _read_signal, reset_rtlsim, rtlsim_multi_io
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.base import NodeLocalTransformation
from finn.util.fpgadataflow import is_fpgadataflow_node
......@@ -80,9 +81,9 @@ class DeriveCharacteristic(NodeLocalTransformation):
"DuplicateStreams_Batch",
"StreamingConcat",
]
assert (
node.op_type not in multistream_optypes
), f"{node.name} unsupported"
if node.op_type in multistream_optypes:
warnings.warn(f"Skipping {node.name} for rtlsim characterization")
return (node, False)
exp_cycles = inst.get_exp_cycles()
n_inps = np.prod(inst.get_folded_input_shape()[:-1])
n_outs = np.prod(inst.get_folded_output_shape()[:-1])
......@@ -183,6 +184,57 @@ class DeriveCharacteristic(NodeLocalTransformation):
)
return (node, False)
def apply(self, model: ModelWrapper):
(model, run_again) = super().apply(model)
# apply manual fix for DuplicateStreams and AddStreams for
# simple residual reconvergent paths with bypass
addstrm_nodes = model.get_nodes_by_op_type("AddStreams")
for addstrm_node in addstrm_nodes:
# we currently only support the case where one branch is
# a bypass
b0 = model.find_producer(addstrm_node.input[0])
b1 = model.find_producer(addstrm_node.input[1])
if (b0 is None) or (b1 is None):
warnings.warn("Found unsupported AddStreams, skipping")
return (model, run_again)
b0_is_bypass = b0.op_type == "DuplicateStreams"
b1_is_bypass = b1.op_type == "DuplicateStreams"
if (not b0_is_bypass) and (not b1_is_bypass):
warnings.warn("Found unsupported AddStreams, skipping")
return (model, run_again)
ds_node = b0 if b0_is_bypass else b1
comp_branch_last = b1 if b0_is_bypass else b0
ds_comp_bout = ds_node.output[0] if b0_is_bypass else ds_node.output[1]
comp_branch_first = model.find_consumer(ds_comp_bout)
if comp_branch_first is None or comp_branch_last is None:
warnings.warn("Found unsupported DuplicateStreams, skipping")
return (model, run_again)
comp_branch_last = registry.getCustomOp(comp_branch_last)
comp_branch_first = registry.getCustomOp(comp_branch_first)
# for DuplicateStreams, use comp_branch_first's input characterization
# for AddStreams, use comp_branch_last's output characterization
period = comp_branch_first.get_nodeattr("io_characteristic_period")
comp_branch_first_f = comp_branch_first.get_nodeattr("io_characteristic")[
: 2 * period
]
comp_branch_last_f = comp_branch_last.get_nodeattr("io_characteristic")[
2 * period :
]
ds_node_inst = registry.getCustomOp(ds_node)
addstrm_node_inst = registry.getCustomOp(addstrm_node)
ds_node_inst.set_nodeattr("io_characteristic_period", period)
ds_node_inst.set_nodeattr("io_characteristic", comp_branch_first_f * 2)
addstrm_node_inst.set_nodeattr("io_characteristic_period", period)
addstrm_node_inst.set_nodeattr("io_characteristic", comp_branch_last_f * 2)
warnings.warn(
f"Set {ds_node.name} chrc. from {comp_branch_first.onnx_node.name}"
)
warnings.warn(
f"Set {addstrm_node.name} chrc. from {comp_branch_last.onnx_node.name}"
)
return (model, run_again)
class DeriveFIFOSizes(NodeLocalTransformation):
"""Prerequisite: DeriveCharacteristic already called on graph.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment