diff --git a/src/finn/util/visualization.py b/src/finn/util/visualization.py index ede7865eab3bc6ce48f9870c7b4e4a171e2ed133..c3cdfc1e127cb7eba1c65b88e9858430441c937e 100644 --- a/src/finn/util/visualization.py +++ b/src/finn/util/visualization.py @@ -36,7 +36,29 @@ def showSrc(what): print("".join(inspect.getsourcelines(what)[0])) -def showInNetron(model_filename): - netron.start(model_filename, address=("0.0.0.0", 8081)) - localhost_url = os.getenv("LOCALHOST_URL", default="localhost") - return IFrame(src="http://%s:8081/" % localhost_url, width="100%", height=400) +def showInNetron(model_filename: str, localhost_url: str = None, port: int = None): + """Shows a ONNX model file in the Jupyter Notebook using Netron. + + :param model_filename: The path to the ONNX model file. + :type model_filename: str + + :param localhost_url: The IP address used by the Jupyter IFrame to show the model. + Defaults to localhost. + :type localhost_url: str, optional + + :param port: The port number used by Netron and the Jupyter IFrame to show + the ONNX model. Defaults to 8081. + :type port: int, optional + + :return: The IFrame displaying the ONNX model. + :rtype: IPython.lib.display.IFrame + """ + if port is None: + try: + port = int(os.environ["NETRON_PORT"]) + except KeyError: + port = 8081 + if localhost_url is None: + localhost_url = os.getenv("LOCALHOST_URL", default="localhost") + netron.start(model_filename, address=("0.0.0.0", port), browse=False) + return IFrame(src=f"http://{localhost_url}:{port}/", width="100%", height=400)