diff --git a/src/finn/core/modelwrapper.py b/src/finn/core/modelwrapper.py index 98b234592ebe0c704fafd1eed980325d8566e7e2..42acc6fd277c9419920edd534e5660e85c7626b0 100644 --- a/src/finn/core/modelwrapper.py +++ b/src/finn/core/modelwrapper.py @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import copy - +import os import onnx import onnx.helper as oh import onnx.numpy_helper as np_helper @@ -41,6 +41,7 @@ from finn.transformation.general import ( RemoveStaticGraphInputs, SortGraph, ) +from finn.transformation.double_to_single_float import DoubleToSingleFloat class ModelWrapper: @@ -51,10 +52,12 @@ class ModelWrapper: """Creates a ModelWrapper instance. onnx_model_proto can be either a ModelProto instance, or a string with the path to a stored .onnx file on disk, or serialized bytes. - The make_deepcopy option controls whether a deep copy of the ModelProto + + - make_deepcopy : controls whether a deep copy of the ModelProto is made internally. """ if isinstance(onnx_model_proto, str): + assert os.path.isfile(onnx_model_proto) self._model_proto = onnx.load(onnx_model_proto) elif isinstance(onnx_model_proto, bytes): self._model_proto = onnx.load_from_string(onnx_model_proto) @@ -92,15 +95,23 @@ class ModelWrapper: """Runs given anaylsis_fxn on this model and return resulting dict.""" return analysis_fxn(self) - def transform(self, transformation, make_deepcopy=True, cleanup=True): + def transform( + self, transformation, make_deepcopy=True, cleanup=True, fix_float64=True + ): """Applies given Transformation repeatedly until no more changes can be made and returns a transformed ModelWrapper instance. - If make_deepcopy is specified, operates on a new (deep)copy of model. + - make_deepcopy : operates on a new (deep)copy of model. + - fix_float64 : DoubleToSingleFloat correction before starting + - cleanup : execute cleanup transformations before returning """ transformed_model = self if make_deepcopy: transformed_model = copy.deepcopy(self) + if fix_float64: + (transformed_model, model_was_changed) = DoubleToSingleFloat().apply( + transformed_model + ) model_was_changed = True while model_was_changed: (transformed_model, model_was_changed) = transformation.apply(