Skip to content
Snippets Groups Projects
Commit a6c863cd authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] test RemoveStaticGraphInputs and RemoveUnusedInitAndValueInfo

parent cad8807d
No related branches found
No related tags found
No related merge requests found
......@@ -38,7 +38,7 @@ import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.general import GiveUniqueNodeNames
from finn.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs
from finn.transformation.double_to_single_float import DoubleToSingleFloat
from finn.util.test import get_test_model_trained
......@@ -57,6 +57,9 @@ def test_brevitas_cnv_export_exec(wbits, abits):
model = model.transform(DoubleToSingleFloat())
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model = model.transform(RemoveStaticGraphInputs())
assert len(model.graph.input) == 1
assert len(model.graph.output) == 1
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"].astype(np.float32)
input_tensor = input_tensor / 255
......
......@@ -39,6 +39,7 @@ import torch
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import RemoveStaticGraphInputs
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import make_build_dir
from finn.util.test import get_test_model_trained
......@@ -63,6 +64,9 @@ def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits):
model = ModelWrapper(finn_onnx)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model = model.transform(RemoveStaticGraphInputs())
assert len(model.graph.input) == 1
assert len(model.graph.output) == 1
# load one of the test vectors
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
......
......@@ -42,7 +42,12 @@ from finn.transformation.double_to_single_float import DoubleToSingleFloat
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
)
from finn.transformation.streamline import Streamline
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
......@@ -97,6 +102,7 @@ def test_end2end_cnv_w1a1_import_and_tidy():
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(RemoveStaticGraphInputs())
model.save(build_dir + "/end2end_cnv_w1a1_tidy.onnx")
......@@ -108,6 +114,7 @@ def test_end2end_cnv_w1a1_streamline():
model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
model = model.transform(ConvertBipolarMatMulToXnorPopcount())
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model.save(build_dir + "/end2end_cnv_w1a1_streamlined.onnx")
......
......@@ -42,7 +42,12 @@ from finn.transformation.double_to_single_float import DoubleToSingleFloat
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
)
from finn.transformation.streamline import Streamline
from finn.transformation.lower_convs_to_matmul import LowerConvsToMatMul
import finn.transformation.streamline.absorb as absorb
......@@ -96,6 +101,7 @@ def test_end2end_cnv_w2a2_import_and_tidy():
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(RemoveStaticGraphInputs())
model.save(build_dir + "/end2end_cnv_w2a2_tidy.onnx")
......@@ -106,6 +112,7 @@ def test_end2end_cnv_w2a2_streamline():
model = model.transform(MakeMaxPoolNHWC())
model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model.save(build_dir + "/end2end_cnv_w2a2_streamlined.onnx")
......
......@@ -63,7 +63,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
)
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
)
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import Streamline
......@@ -98,12 +103,14 @@ def test_end2end_tfc_w1a1_import_and_tidy():
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferDataTypes())
model = model.transform(RemoveStaticGraphInputs())
model.save(build_dir + "/end2end_tfc_w1a1_tidy.onnx")
def test_end2end_tfc_w1a1_streamline():
model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a1_tidy.onnx")
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model.save(build_dir + "/end2end_tfc_w1a1_streamlined.onnx")
......
......@@ -61,7 +61,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
)
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
)
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import Streamline
......@@ -93,12 +98,14 @@ def test_end2end_tfc_w1a2_import_and_tidy():
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferDataTypes())
model = model.transform(RemoveStaticGraphInputs())
model.save(build_dir + "/end2end_tfc_w1a2_tidy.onnx")
def test_end2end_tfc_w1a2_streamline():
model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a2_tidy.onnx")
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model.save(build_dir + "/end2end_tfc_w1a2_streamlined.onnx")
......
......@@ -61,7 +61,12 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
)
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
)
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import Streamline
......@@ -93,12 +98,14 @@ def test_end2end_tfc_w2a2_import_and_tidy():
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(InferDataTypes())
model = model.transform(RemoveStaticGraphInputs())
model.save(build_dir + "/end2end_tfc_w2a2_tidy.onnx")
def test_end2end_tfc_w2a2_streamline():
model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w2a2_tidy.onnx")
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model.save(build_dir + "/end2end_tfc_w2a2_streamlined.onnx")
......
......@@ -34,7 +34,12 @@ import pkg_resources as pk
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
)
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import Streamline
from finn.util.test import get_test_model_trained
......@@ -62,6 +67,7 @@ def test_streamline_cnv(size, wbits, abits):
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(RemoveStaticGraphInputs())
# load one of the test vectors
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"].astype(np.float32)
......@@ -73,6 +79,9 @@ def test_streamline_cnv(size, wbits, abits):
expected = expected_ctx[model.graph.output[0].name]
# model.save("orig_cnv.onnx")
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
assert len(model.graph.initializer) == 21
assert len(model.graph.value_info) == 43
# model.save("streamlined_cnv.onnx")
assert len(model.graph.node) == 23
produced_ctx = oxe.execute_onnx(model, input_dict, True)
......
......@@ -37,7 +37,12 @@ import pytest
import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
)
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.streamline import Streamline
from finn.util.test import get_test_model_trained
......@@ -65,6 +70,7 @@ def test_streamline_fc(size, wbits, abits):
model = model.transform(FoldConstants())
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())
model = model.transform(RemoveStaticGraphInputs())
# load one of the test vectors
raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
input_tensor = onnx.load_tensor_from_string(raw_i)
......@@ -73,6 +79,9 @@ def test_streamline_fc(size, wbits, abits):
expected_ctx = oxe.execute_onnx(model, input_dict, True)
expected = expected_ctx[model.graph.output[0].name]
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
assert len(model.graph.initializer) == 11
assert len(model.graph.value_info) == 21
produced_ctx = oxe.execute_onnx(model, input_dict, True)
produced = produced_ctx[model.graph.output[0].name]
assert np.isclose(expected, produced, atol=1e-3).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment