diff --git a/src/finn/transformation/fpgadataflow/derive_characteristic.py b/src/finn/transformation/fpgadataflow/derive_characteristic.py index a9b291ba5b086a31855079233a00164e190cfee9..f857cdb5ef7328bec34b798010d3fbb167a61208 100644 --- a/src/finn/transformation/fpgadataflow/derive_characteristic.py +++ b/src/finn/transformation/fpgadataflow/derive_characteristic.py @@ -31,6 +31,7 @@ import numpy as np import qonnx.custom_op.registry as registry import warnings from pyverilator.util.axi_utils import _read_signal, reset_rtlsim, rtlsim_multi_io +from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import NodeLocalTransformation from finn.util.fpgadataflow import is_fpgadataflow_node @@ -80,9 +81,9 @@ class DeriveCharacteristic(NodeLocalTransformation): "DuplicateStreams_Batch", "StreamingConcat", ] - assert ( - node.op_type not in multistream_optypes - ), f"{node.name} unsupported" + if node.op_type in multistream_optypes: + warnings.warn(f"Skipping {node.name} for rtlsim characterization") + return (node, False) exp_cycles = inst.get_exp_cycles() n_inps = np.prod(inst.get_folded_input_shape()[:-1]) n_outs = np.prod(inst.get_folded_output_shape()[:-1]) @@ -183,6 +184,57 @@ class DeriveCharacteristic(NodeLocalTransformation): ) return (node, False) + def apply(self, model: ModelWrapper): + (model, run_again) = super().apply(model) + # apply manual fix for DuplicateStreams and AddStreams for + # simple residual reconvergent paths with bypass + addstrm_nodes = model.get_nodes_by_op_type("AddStreams") + for addstrm_node in addstrm_nodes: + # we currently only support the case where one branch is + # a bypass + b0 = model.find_producer(addstrm_node.input[0]) + b1 = model.find_producer(addstrm_node.input[1]) + if (b0 is None) or (b1 is None): + warnings.warn("Found unsupported AddStreams, skipping") + return (model, run_again) + b0_is_bypass = b0.op_type == "DuplicateStreams" + b1_is_bypass = b1.op_type == "DuplicateStreams" + if (not b0_is_bypass) and (not b1_is_bypass): + warnings.warn("Found unsupported AddStreams, skipping") + return (model, run_again) + ds_node = b0 if b0_is_bypass else b1 + comp_branch_last = b1 if b0_is_bypass else b0 + + ds_comp_bout = ds_node.output[0] if b0_is_bypass else ds_node.output[1] + comp_branch_first = model.find_consumer(ds_comp_bout) + if comp_branch_first is None or comp_branch_last is None: + warnings.warn("Found unsupported DuplicateStreams, skipping") + return (model, run_again) + comp_branch_last = registry.getCustomOp(comp_branch_last) + comp_branch_first = registry.getCustomOp(comp_branch_first) + # for DuplicateStreams, use comp_branch_first's input characterization + # for AddStreams, use comp_branch_last's output characterization + period = comp_branch_first.get_nodeattr("io_characteristic_period") + comp_branch_first_f = comp_branch_first.get_nodeattr("io_characteristic")[ + : 2 * period + ] + comp_branch_last_f = comp_branch_last.get_nodeattr("io_characteristic")[ + 2 * period : + ] + ds_node_inst = registry.getCustomOp(ds_node) + addstrm_node_inst = registry.getCustomOp(addstrm_node) + ds_node_inst.set_nodeattr("io_characteristic_period", period) + ds_node_inst.set_nodeattr("io_characteristic", comp_branch_first_f * 2) + addstrm_node_inst.set_nodeattr("io_characteristic_period", period) + addstrm_node_inst.set_nodeattr("io_characteristic", comp_branch_last_f * 2) + warnings.warn( + f"Set {ds_node.name} chrc. from {comp_branch_first.onnx_node.name}" + ) + warnings.warn( + f"Set {addstrm_node.name} chrc. from {comp_branch_last.onnx_node.name}" + ) + return (model, run_again) + class DeriveFIFOSizes(NodeLocalTransformation): """Prerequisite: DeriveCharacteristic already called on graph.