Skip to content
Snippets Groups Projects
Commit 245cefe1 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] rename to RemoveUnusedTensors, also rm annotations + test

parent a6c863cd
No related branches found
No related tags found
No related merge requests found
......@@ -31,8 +31,11 @@ from finn.transformation import Transformation
from toposort import toposort_flatten
class RemoveUnusedInitAndValueInfo(Transformation):
"Remove any unused initializers and value_info in the graph."
class RemoveUnusedTensors(Transformation):
"""Remove any unused tensors in the graph by removing any initializers,
ValueInfo and tensor annotations associated with it. Unused tensors do not
appear as any input/output for any graph nodes.
"""
def apply(self, model):
graph_modified = False
......@@ -44,7 +47,8 @@ class RemoveUnusedInitAndValueInfo(Transformation):
used_tensors.add(i)
for o in node.output:
used_tensors.add(o)
# remove initializers and value_info not in the used set
# remove initializers, value_info and annotations that are not in the
# used set of tensors, as determined by the graph node i/o
for init in onnx_graph.initializer:
if init.name not in used_tensors:
onnx_graph.initializer.remove(init)
......@@ -53,6 +57,10 @@ class RemoveUnusedInitAndValueInfo(Transformation):
if vi.name not in used_tensors:
onnx_graph.value_info.remove(vi)
graph_modified = True
for qa in onnx_graph.quantization_annotation:
if qa.tensor_name not in used_tensors:
onnx_graph.quantization_annotation.remove(qa)
graph_modified = True
return (model, graph_modified)
......
......@@ -43,7 +43,7 @@ from finn.transformation.infer_shapes import InferShapes
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveUnusedTensors,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
......@@ -114,7 +114,7 @@ def test_end2end_cnv_w1a1_streamline():
model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
model = model.transform(ConvertBipolarMatMulToXnorPopcount())
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model = model.transform(RemoveUnusedTensors())
model.save(build_dir + "/end2end_cnv_w1a1_streamlined.onnx")
......
......@@ -43,7 +43,7 @@ from finn.transformation.infer_shapes import InferShapes
from finn.transformation.move_reshape import RemoveCNVtoFCFlatten
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveUnusedTensors,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
......@@ -112,7 +112,7 @@ def test_end2end_cnv_w2a2_streamline():
model = model.transform(MakeMaxPoolNHWC())
model = model.transform(absorb.AbsorbTransposeIntoMultiThreshold())
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model = model.transform(RemoveUnusedTensors())
model.save(build_dir + "/end2end_cnv_w2a2_streamlined.onnx")
......
......@@ -64,7 +64,7 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveUnusedTensors,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
......@@ -110,7 +110,7 @@ def test_end2end_tfc_w1a1_import_and_tidy():
def test_end2end_tfc_w1a1_streamline():
model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a1_tidy.onnx")
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model = model.transform(RemoveUnusedTensors())
model.save(build_dir + "/end2end_tfc_w1a1_streamlined.onnx")
......
......@@ -62,7 +62,7 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveUnusedTensors,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
......@@ -105,7 +105,7 @@ def test_end2end_tfc_w1a2_import_and_tidy():
def test_end2end_tfc_w1a2_streamline():
model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w1a2_tidy.onnx")
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model = model.transform(RemoveUnusedTensors())
model.save(build_dir + "/end2end_tfc_w1a2_streamlined.onnx")
......
......@@ -62,7 +62,7 @@ from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveUnusedTensors,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
......@@ -105,7 +105,7 @@ def test_end2end_tfc_w2a2_import_and_tidy():
def test_end2end_tfc_w2a2_streamline():
model = load_test_checkpoint_or_skip(build_dir + "/end2end_tfc_w2a2_tidy.onnx")
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model = model.transform(RemoveUnusedTensors())
model.save(build_dir + "/end2end_tfc_w2a2_streamlined.onnx")
......
......@@ -35,7 +35,7 @@ import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveUnusedTensors,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
......@@ -79,7 +79,7 @@ def test_streamline_cnv(size, wbits, abits):
expected = expected_ctx[model.graph.output[0].name]
# model.save("orig_cnv.onnx")
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model = model.transform(RemoveUnusedTensors())
assert len(model.graph.initializer) == 21
assert len(model.graph.value_info) == 43
# model.save("streamlined_cnv.onnx")
......
......@@ -38,7 +38,7 @@ import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import (
RemoveUnusedInitAndValueInfo,
RemoveUnusedTensors,
RemoveStaticGraphInputs,
GiveReadableTensorNames,
GiveUniqueNodeNames,
......@@ -79,9 +79,10 @@ def test_streamline_fc(size, wbits, abits):
expected_ctx = oxe.execute_onnx(model, input_dict, True)
expected = expected_ctx[model.graph.output[0].name]
model = model.transform(Streamline())
model = model.transform(RemoveUnusedInitAndValueInfo())
model = model.transform(RemoveUnusedTensors())
assert len(model.graph.initializer) == 11
assert len(model.graph.value_info) == 21
assert len(model.graph.quantization_annotation) == 18
produced_ctx = oxe.execute_onnx(model, input_dict, True)
produced = produced_ctx[model.graph.output[0].name]
assert np.isclose(expected, produced, atol=1e-3).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment