Skip to content
Snippets Groups Projects
Commit 0136b4d2 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] bring back conv for batchnorm2affine test

parent 98dc81dc
No related branches found
No related tags found
No related merge requests found
......@@ -58,18 +58,17 @@ def test_batchnorm_to_affine_lfc_w1a1():
os.remove(export_onnx_path)
# cnv batchnorm to affine not yet supported
# def test_batchnorm_to_affine_cnv_w1a1():
# lfc = get_test_model_trained("CNV", 1, 1)
# bo.export_finn_onnx(lfc, (1, 3, 32, 32), export_onnx_path)
# model = ModelWrapper(export_onnx_path)
# model = model.transform(InferShapes())
# model = model.transform(FoldConstants())
# # TODO shape inference failing on transformed model below -- needs debug
# new_model = model.transform(BatchNormToAffine())
# # check that there are no BN nodes left
# # TODO replace this with execution test
# op_types = list(map(lambda x: x.op_type, new_model.graph.node))
# assert "BatchNormalization" not in op_types
# os.remove(export_onnx_path)
def test_batchnorm_to_affine_cnv_w1a1():
lfc = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(lfc, (1, 3, 32, 32), export_onnx_path)
model = ModelWrapper(export_onnx_path)
model = model.transform(InferShapes())
model = model.transform(FoldConstants())
model.save("old.onnx")
# TODO shape inference failing on transformed model below -- needs debug
new_model = model.transform(BatchNormToAffine())
# check that there are no BN nodes left
# TODO replace this with execution test
op_types = list(map(lambda x: x.op_type, new_model.graph.node))
assert "BatchNormalization" not in op_types
# os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment