diff --git a/src/finn/util/visualization.py b/src/finn/util/visualization.py index ede7865eab3bc6ce48f9870c7b4e4a171e2ed133..397bebb64c21e9c5a1cb09d2a01fe3e10502b558 100644 --- a/src/finn/util/visualization.py +++ b/src/finn/util/visualization.py @@ -36,7 +36,27 @@ def showSrc(what): print("".join(inspect.getsourcelines(what)[0])) -def showInNetron(model_filename): - netron.start(model_filename, address=("0.0.0.0", 8081)) - localhost_url = os.getenv("LOCALHOST_URL", default="localhost") - return IFrame(src="http://%s:8081/" % localhost_url, width="100%", height=400) +def showInNetron(model_filename: str, localhost_url: str = None, port: int = None): + """Shows a ONNX model file in the Jupyter Notebook using Netron. + + :param model_filename: The path to the ONNX model file. + :type model_filename: str + + :param localhost_url: The IP address used by the Jupyter IFrame to show the model. + Defaults to localhost. + :type localhost_url: str, optional + + :param port: The port number used by Netron and the Jupyter IFrame to show + the ONNX model. Defaults to 8081. + :type port: int, optional + + :return: The IFrame displaying the ONNX model. + :rtype: IPython.lib.display.IFrame + """ + try: + port = port or int(os.getenv("NETRON_PORT", default="8081")) + except ValueError: + port = 8081 + localhost_url = localhost_url or os.getenv("LOCALHOST_URL", default="localhost") + netron.start(model_filename, address=("0.0.0.0", port), browse=False) + return IFrame(src=f"http://{localhost_url}:{port}/", width="100%", height=400)