Skip to content
Snippets Groups Projects
Commit 4932f858 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add ConvertDivToMul transform, call in streamlining

parent 458788fa
No related branches found
No related tags found
No related merge requests found
......@@ -82,7 +82,7 @@ class GiveReadableTensorNames(Transformation):
class ConvertSubToAdd(Transformation):
"""Convert sub nodes to add nodes of appropriate sign."""
"""Convert subtract-a-constant nodes to add-a-constant nodes."""
def apply(self, model):
graph = model.graph
......@@ -94,3 +94,18 @@ class ConvertSubToAdd(Transformation):
model.set_initializer(n.input[1], -A)
# return model_was_changed = False as single iteration is always enough
return (model, False)
class ConvertDivToMul(Transformation):
"""Convert divide by constant nodes to multiply by constant nodes."""
def apply(self, model):
graph = model.graph
for n in graph.node:
if n.op_type == "Div":
A = model.get_initializer(n.input[1])
if A is not None:
n.op_type = "Mul"
model.set_initializer(n.input[1], 1.0 / A)
# return model_was_changed = False as single iteration is always enough
return (model, False)
......@@ -30,6 +30,7 @@ from finn.transformation import Transformation
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.general import (
ConvertSubToAdd,
ConvertDivToMul,
GiveReadableTensorNames,
GiveUniqueNodeNames,
)
......@@ -64,6 +65,7 @@ class Streamline(Transformation):
def apply(self, model):
streamline_transformations = [
ConvertSubToAdd(),
ConvertDivToMul(),
BatchNormToAffine(),
ConvertSignToThres(),
MoveAddPastMul(),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment