From 985d21d70f61294bf607acdd16ffaf39ec2e2206 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Fri, 11 Sep 2020 11:29:19 +0200
Subject: [PATCH] [Core] execute DoubleToSingleFloat prior to transforms

---
 src/finn/core/modelwrapper.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py
index 98b234592..42acc6fd2 100644
--- a/src/finn/core/modelwrapper.py
+++ b/src/finn/core/modelwrapper.py
@@ -27,7 +27,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import copy
-
+import os
 import onnx
 import onnx.helper as oh
 import onnx.numpy_helper as np_helper
@@ -41,6 +41,7 @@ from finn.transformation.general import (
     RemoveStaticGraphInputs,
     SortGraph,
 )
+from finn.transformation.double_to_single_float import DoubleToSingleFloat
 
 
 class ModelWrapper:
@@ -51,10 +52,12 @@ class ModelWrapper:
         """Creates a ModelWrapper instance.
         onnx_model_proto can be either a ModelProto instance, or a string
         with the path to a stored .onnx file on disk, or serialized bytes.
-        The make_deepcopy option controls whether a deep copy of the ModelProto
+
+        - make_deepcopy : controls whether a deep copy of the ModelProto
         is made internally.
         """
         if isinstance(onnx_model_proto, str):
+            assert os.path.isfile(onnx_model_proto)
             self._model_proto = onnx.load(onnx_model_proto)
         elif isinstance(onnx_model_proto, bytes):
             self._model_proto = onnx.load_from_string(onnx_model_proto)
@@ -92,15 +95,23 @@ class ModelWrapper:
         """Runs given anaylsis_fxn on this model and return resulting dict."""
         return analysis_fxn(self)
 
-    def transform(self, transformation, make_deepcopy=True, cleanup=True):
+    def transform(
+        self, transformation, make_deepcopy=True, cleanup=True, fix_float64=True
+    ):
         """Applies given Transformation repeatedly until no more changes can be made
         and returns a transformed ModelWrapper instance.
 
-        If make_deepcopy is specified, operates on a new (deep)copy of model.
+        - make_deepcopy : operates on a new (deep)copy of model.
+        - fix_float64 : DoubleToSingleFloat correction before starting
+        - cleanup : execute cleanup transformations before returning
         """
         transformed_model = self
         if make_deepcopy:
             transformed_model = copy.deepcopy(self)
+        if fix_float64:
+            (transformed_model, model_was_changed) = DoubleToSingleFloat().apply(
+                transformed_model
+            )
         model_was_changed = True
         while model_was_changed:
             (transformed_model, model_was_changed) = transformation.apply(
-- 
GitLab