Skip to content

Commit 479b826

Browse files
zrphercule2facebook-github-bot
authored andcommitted
Back out "[pytorch][PR] Support upsample" (pytorch#13413)
Summary: Pull Request resolved: pytorch#13413 Original commit changeset: d5db200365f1 Reviewed By: houseroad Differential Revision: D12870356 fbshipit-source-id: be115d2370636786901c822895664ccace2a9bc2
1 parent a477886 commit 479b826

File tree

3 files changed

+15
-23
lines changed

3 files changed

+15
-23
lines changed

test/onnx/expect/TestOperators.test_upsample.expect

+10-16
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,23 @@ ir_version: 3
22
producer_name: "pytorch"
33
producer_version: "0.4"
44
graph {
5-
node {
6-
output: "1"
7-
op_type: "Constant"
8-
attribute {
9-
name: "value"
10-
t {
11-
dims: 4
12-
data_type: FLOAT
13-
raw_data: "\000\000\200?\000\000\200?\000\000\000@\000\000\000@"
14-
}
15-
type: TENSOR
16-
}
17-
}
185
node {
196
input: "0"
20-
input: "1"
21-
output: "2"
7+
output: "1"
228
op_type: "Upsample"
239
attribute {
2410
name: "mode"
2511
s: "linear"
2612
type: STRING
2713
}
14+
attribute {
15+
name: "scales"
16+
floats: 1
17+
floats: 1
18+
floats: 2
19+
floats: 2
20+
type: FLOATS
21+
}
2822
}
2923
name: "torch-jit-export"
3024
input {
@@ -50,7 +44,7 @@ graph {
5044
}
5145
}
5246
output {
53-
name: "2"
47+
name: "1"
5448
type {
5549
tensor_type {
5650
elem_type: FLOAT

test/onnx/test_operators.py

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def assertONNX(self, f, args, params=None, **kwargs):
7575
import test_onnx_common
7676
model_def = onnx.ModelProto.FromString(onnx_model_pb)
7777
onnx.checker.check_model(model_def)
78+
7879
if _onnx_test:
7980
test_function = inspect.stack()[1][0].f_code.co_name
8081
test_name = test_function[0:4] + "_operator" + test_function[4:]

torch/onnx/symbolic.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,8 @@ def replication_pad(g, input, padding):
642642
def upsample_nearest2d(g, input, output_size):
643643
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
644644
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
645-
scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
646-
width_scale]))
647-
648-
return g.op("Upsample", input, scales,
645+
return g.op("Upsample", input,
646+
scales_f=[1., 1., height_scale, width_scale],
649647
mode_s="nearest")
650648

651649

@@ -655,9 +653,8 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
655653
return _unimplemented("upsample_bilinear2d", "align_corners == True")
656654
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
657655
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
658-
scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
659-
width_scale]))
660-
return g.op("Upsample", input, scales,
656+
return g.op("Upsample", input,
657+
scales_f=[1., 1., height_scale, width_scale],
661658
mode_s="linear")
662659

663660

0 commit comments

Comments
 (0)