diff --git a/tests/custom_op/test_im2col.py b/tests/custom_op/test_im2col.py index 946653d671043c9272305588d4304e06b69da3c4..0b148145bd6f7d9819e4b72da5333d2556de94e4 100644 --- a/tests/custom_op/test_im2col.py +++ b/tests/custom_op/test_im2col.py @@ -46,7 +46,7 @@ def execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim, pad_amt=0, pad_val=0): ) graph = helper.make_graph( - nodes=[Im2Col_node], name="im2col_graph", inputs=[inp], outputs=[outp], + nodes=[Im2Col_node], name="im2col_graph", inputs=[inp], outputs=[outp] ) model = helper.make_model(graph, producer_name="im2col-model") @@ -175,19 +175,19 @@ def test_im2col(): [ [ [ - [1.0, 2.0, 5.0, 6.0, -1.0, -2.0, -5.0, -6.0], - [2.0, 3.0, 6.0, 7.0, -2.0, -3.0, -6.0, -7.0], - [3.0, 4.0, 7.0, 8.0, -3.0, -4.0, -7.0, -8.0], + [1.0, -1.0, 2.0, -2.0, 5.0, -5.0, 6.0, -6.0], + [2.0, -2.0, 3.0, -3.0, 6.0, -6.0, 7.0, -7.0], + [3.0, -3.0, 4.0, -4.0, 7.0, -7.0, 8.0, -8.0], ], [ - [5.0, 6.0, 9.0, 10.0, -5.0, -6.0, -9.0, -10.0], - [6.0, 7.0, 10.0, 11.0, -6.0, -7.0, -10.0, -11.0], - [7.0, 8.0, 11.0, 12.0, -7.0, -8.0, -11.0, -12.0], + [5.0, -5.0, 6.0, -6.0, 9.0, -9.0, 10.0, -10.0], + [6.0, -6.0, 7.0, -7.0, 10.0, -10.0, 11.0, -11.0], + [7.0, -7.0, 8.0, -8.0, 11.0, -11.0, 12.0, -12.0], ], [ - [9.0, 10.0, 13.0, 14.0, -9.0, -10.0, -13.0, -14.0], - [10.0, 11.0, 14.0, 15.0, -10.0, -11.0, -14.0, -15.0], - [11.0, 12.0, 15.0, 16.0, -11.0, -12.0, -15.0, -16.0], + [9.0, -9.0, 10.0, -10.0, 13.0, -13.0, 14.0, -14.0], + [10.0, -10.0, 11.0, -11.0, 14.0, -14.0, 15.0, -15.0], + [11.0, -11.0, 12.0, -12.0, 15.0, -15.0, 16.0, -16.0], ], ] ], @@ -222,39 +222,39 @@ def test_im2col(): [ [ [ - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0], - [0.0, 0.0, 1.0, 2.0, 0.0, 0.0, -1.0, -2.0], - [0.0, 0.0, 2.0, 3.0, 0.0, 0.0, -2.0, -3.0], - [0.0, 0.0, 3.0, 4.0, 0.0, 0.0, -3.0, -4.0], - [0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -4.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0], + [0.0, 0.0, 0.0, 0.0, 1.0, -1.0, 2.0, -2.0], + [0.0, 0.0, 0.0, 0.0, 2.0, -2.0, 3.0, -3.0], + [0.0, 0.0, 0.0, 0.0, 3.0, -3.0, 4.0, -4.0], + [0.0, 0.0, 0.0, 0.0, 4.0, -4.0, 0.0, 0.0], ], [ - [0.0, 1.0, 0.0, 5.0, 0.0, -1.0, 0.0, -5.0], - [1.0, 2.0, 5.0, 6.0, -1.0, -2.0, -5.0, -6.0], - [2.0, 3.0, 6.0, 7.0, -2.0, -3.0, -6.0, -7.0], - [3.0, 4.0, 7.0, 8.0, -3.0, -4.0, -7.0, -8.0], - [4.0, 0.0, 8.0, 0.0, -4.0, 0.0, -8.0, 0.0], + [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 5.0, -5.0], + [1.0, -1.0, 2.0, -2.0, 5.0, -5.0, 6.0, -6.0], + [2.0, -2.0, 3.0, -3.0, 6.0, -6.0, 7.0, -7.0], + [3.0, -3.0, 4.0, -4.0, 7.0, -7.0, 8.0, -8.0], + [4.0, -4.0, 0.0, 0.0, 8.0, -8.0, 0.0, 0.0], ], [ - [0.0, 5.0, 0.0, 9.0, 0.0, -5.0, 0.0, -9.0], - [5.0, 6.0, 9.0, 10.0, -5.0, -6.0, -9.0, -10.0], - [6.0, 7.0, 10.0, 11.0, -6.0, -7.0, -10.0, -11.0], - [7.0, 8.0, 11.0, 12.0, -7.0, -8.0, -11.0, -12.0], - [8.0, 0.0, 12.0, 0.0, -8.0, 0.0, -12.0, 0.0], + [0.0, 0.0, 5.0, -5.0, 0.0, 0.0, 9.0, -9.0], + [5.0, -5.0, 6.0, -6.0, 9.0, -9.0, 10.0, -10.0], + [6.0, -6.0, 7.0, -7.0, 10.0, -10.0, 11.0, -11.0], + [7.0, -7.0, 8.0, -8.0, 11.0, -11.0, 12.0, -12.0], + [8.0, -8.0, 0.0, 0.0, 12.0, -12.0, 0.0, 0.0], ], [ - [0.0, 9.0, 0.0, 13.0, 0.0, -9.0, 0.0, -13.0], - [9.0, 10.0, 13.0, 14.0, -9.0, -10.0, -13.0, -14.0], - [10.0, 11.0, 14.0, 15.0, -10.0, -11.0, -14.0, -15.0], - [11.0, 12.0, 15.0, 16.0, -11.0, -12.0, -15.0, -16.0], - [12.0, 0.0, 16.0, 0.0, -12.0, 0.0, -16.0, 0.0], + [0.0, 0.0, 9.0, -9.0, 0.0, 0.0, 13.0, -13.0], + [9.0, -9.0, 10.0, -10.0, 13.0, -13.0, 14.0, -14.0], + [10.0, -10.0, 11.0, -11.0, 14.0, -14.0, 15.0, -15.0], + [11.0, -11.0, 12.0, -12.0, 15.0, -15.0, 16.0, -16.0], + [12.0, -12.0, 0.0, 0.0, 16.0, -16.0, 0.0, 0.0], ], [ - [0.0, 13.0, 0.0, 0.0, 0.0, -13.0, 0.0, 0.0], - [13.0, 14.0, 0.0, 0.0, -13.0, -14.0, 0.0, 0.0], - [14.0, 15.0, 0.0, 0.0, -14.0, -15.0, 0.0, 0.0], - [15.0, 16.0, 0.0, 0.0, -15.0, -16.0, 0.0, 0.0], - [16.0, 0.0, 0.0, 0.0, -16.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 13.0, -13.0, 0.0, 0.0, 0.0, 0.0], + [13.0, -13.0, 14.0, -14.0, 0.0, 0.0, 0.0, 0.0], + [14.0, -14.0, 15.0, -15.0, 0.0, 0.0, 0.0, 0.0], + [15.0, -15.0, 16.0, -16.0, 0.0, 0.0, 0.0, 0.0], + [16.0, -16.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], ] ], @@ -281,7 +281,7 @@ def test_im2col_infer_shapes(): "outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch] ) - abs_node = helper.make_node("Abs", inputs=["inp"], outputs=["abs"],) + abs_node = helper.make_node("Abs", inputs=["inp"], outputs=["abs"]) Im2Col_node = helper.make_node( "Im2Col", @@ -293,7 +293,7 @@ def test_im2col_infer_shapes(): input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch), ) - abs1_node = helper.make_node("Abs", inputs=["im2col"], outputs=["outp"],) + abs1_node = helper.make_node("Abs", inputs=["im2col"], outputs=["outp"]) graph = helper.make_graph( nodes=[abs_node, Im2Col_node, abs1_node],