
Commit f2d548d

code-dev05, mvafin, and p-wysocki authored
[PT FE] Added support for aten::hstack and aten::vstack (#28933)
### Details:
- *added aten::hstack*
- *added aten::vstack*

### Tickets:
- Closes #28875

---------

Co-authored-by: Maxim Vafin <[email protected]>
Co-authored-by: Przemyslaw Wysocki <[email protected]>
1 parent 9e18e1f commit f2d548d
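For context, a minimal usage sketch of the two ops this commit covers. The model, shapes, and names below are illustrative assumptions, not part of the change; `ov.convert_model` is the generic entry point to the PyTorch frontend where these converters get exercised.

import torch
import openvino as ov

# Hypothetical model calling the newly supported ops.
class StackModel(torch.nn.Module):
    def forward(self, x, y):
        h = torch.hstack([x, y])  # column-wise stack: concatenation along dim 1 for >=2-D inputs
        v = torch.vstack([x, y])  # row-wise stack: concatenation along dim 0
        return h, v

# Converting such a model through the PyTorch frontend should now succeed.
example = (torch.randn(2, 3), torch.randn(2, 3))
ov_model = ov.convert_model(StackModel(), example_input=example)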

File tree

4 files changed: +200 −0 lines changed

src/frontends/pytorch/src/op/cat.cpp (+22)
@@ -155,6 +155,28 @@ OutputVector translate_stack_fx(const NodeContext& context) {
     return translate_cat_common(context, list_elems, axis, true);
 }
 
+OutputVector translate_hstack(const NodeContext& context) {
+    num_inputs_check(context, 1, 2);
+    const auto&& list_elems = get_list_as_outputs(context.get_input(0));
+    int64_t axis = 1;
+    auto out = translate_cat_common(context, list_elems, axis, false);
+    if (!context.input_is_none(1)) {
+        context.mutate_input(1, out[0]);
+    }
+    return out;
+};
+
+OutputVector translate_vstack(const NodeContext& context) {
+    num_inputs_check(context, 1, 2);
+    const auto&& list_elems = get_list_as_outputs(context.get_input(0));
+    int64_t axis = 0;
+    auto out = translate_cat_common(context, list_elems, axis, false);
+    if (!context.input_is_none(1)) {
+        context.mutate_input(1, out[0]);
+    }
+    return out;
+};
+
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
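Both converters delegate to translate_cat_common with a fixed axis (1 for hstack, 0 for vstack) and write the result back to the optional out argument via mutate_input. A small sanity check of that axis mapping on the PyTorch side, assuming inputs with ndim >= 2 (1-D inputs follow different promotion rules in PyTorch and are not covered here); tensor shapes are arbitrary:

import torch

a, b = torch.randn(2, 1, 3), torch.randn(2, 1, 3)
assert torch.equal(torch.hstack([a, b]), torch.cat([a, b], dim=1))  # matches axis = 1 above
assert torch.equal(torch.vstack([a, b]), torch.cat([a, b], dim=0))  # matches axis = 0 above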

src/frontends/pytorch/src/op_table.cpp (+4)
@@ -124,6 +124,7 @@ OP_CONVERTER(translate_group_norm);
 OP_CONVERTER(translate_gru);
 OP_CONVERTER(translate_hann_window);
 OP_CONVERTER(translate_hardtanh);
+OP_CONVERTER(translate_hstack);
 OP_CONVERTER(translate_if);
 OP_CONVERTER(translate_im2col);
 OP_CONVERTER(translate_index);
@@ -341,6 +342,7 @@ OP_CONVERTER(translate_quantize_per_channel_fx);
 OP_CONVERTER(translate_quantize_per_tensor_fx);
 OP_CONVERTER(translate_var_fx);
 OP_CONVERTER(translate_var_mean_fx);
+OP_CONVERTER(translate_vstack);
 OP_CONVERTER(translate_unbind_int_fx);
 OP_CONVERTER(translate_zeros_fx);
 OP_CONVERTER(translate_zeros_like_fx);
@@ -531,6 +533,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"aten::hardsigmoid", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSigmoid>>},
     {"aten::hardswish", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
     {"aten::hardtanh", op::quantizable_op<op::translate_hardtanh>},
+    {"aten::hstack", op::translate_hstack},
     {"aten::im2col", op::translate_im2col},
     {"aten::imag", common_translators::translate_imag},
     // aten::index - Supported in limited set of patterns
@@ -737,6 +740,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"aten::view_as", op::translate_reshape_as},
     {"aten::view_as_complex", op::translate_view_as_complex},
     {"aten::view_as_real", op::translate_view_as_real},
+    {"aten::vstack", op::translate_vstack},
     {"aten::wait", op::skip_node},
     {"aten::where", op::translate_where},
     {"aten::zero", op::translate_zeros_like},
New file: PyTorch layer test for aten::hstack (+87)

@@ -0,0 +1,87 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest


class aten_hstack(torch.nn.Module):
    def forward(self, x):
        return torch.hstack(self.prepare_input(x))

    def prepare_input(self, x):
        return (x, x)


class aten_hstack_out(aten_hstack):
    def forward(self, x, out):
        return torch.hstack(self.prepare_input(x), out=out), out


class TestHstack(PytorchLayerTest):
    def _prepare_input(self, out=False, num_repeats=2):
        data = np.random.randn(2, 1, 3)
        if not out:
            return (data, )
        concat = [data for _ in range(num_repeats)]
        out = np.zeros_like(np.concatenate(concat, axis=1))
        return (data, out)

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("out", [False, True])
    def test_hstack(self, out, ie_device, precision, ir_version):
        model = aten_hstack() if not out else aten_hstack_out()
        self._test(model, None, "aten::hstack", ie_device,
                   precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})


class TestHstackAlignTypes(PytorchLayerTest):
    def _prepare_input(self, in_types):
        in_vals = []
        for i in range(len(in_types)):
            dtype = in_types[i]
            in_vals.append(np.random.randn(2, 1, 3).astype(dtype))
        return in_vals

    def create_model(self, in_count):
        class aten_align_types_hstack_two_args(torch.nn.Module):
            def forward(self, x, y):
                ins = [x, y]
                return torch.hstack(ins)

        class aten_align_types_hstack_three_args(torch.nn.Module):
            def forward(self, x, y, z):
                ins = [x, y, z]
                return torch.hstack(ins)

        if in_count == 2:
            return aten_align_types_hstack_two_args()

        if in_count == 3:
            return aten_align_types_hstack_three_args()

    @pytest.mark.parametrize(("in_types"), [
        (np.float32, np.int32),
        (np.int32, np.float32),
        (np.float16, np.float32),
        (np.int16, np.float16),
        (np.int32, np.int64),
        # Three inputs
        (np.float32, np.int32, np.int32),
        (np.float32, np.int32, np.float32),
        (np.int32, np.float32, np.int32),
        (np.float32, np.int32, np.int16),
        (np.int32, np.float32, np.int16),
        (np.int16, np.int32, np.int16),
        (np.float16, np.float32, np.float16),
        (np.float32, np.float16, np.float32),
        (np.float16, np.int32, np.int16),
        (np.int16, np.float16, np.int16)
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_align_types_hstack(self, ie_device, precision, ir_version, in_types):
        self._test(self.create_model(len(in_types)), None, "aten::hstack",
                   ie_device, precision, ir_version, kwargs_to_prepare_input={"in_types": in_types})
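The out= variant exercised by aten_hstack_out corresponds to the optional second input that the converter handles with context.mutate_input. For reference, a small sketch of that PyTorch-side behavior, with shapes chosen to match the (2, 1, 3) test inputs:

import torch

x = torch.randn(2, 1, 3)
out = torch.empty(2, 2, 3)       # hstack of two (2, 1, 3) tensors -> (2, 2, 3)
torch.hstack((x, x), out=out)    # result is written into the preallocated tensor
assert torch.equal(out, torch.cat((x, x), dim=1))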
New file: PyTorch layer test for aten::vstack (+87)

@@ -0,0 +1,87 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest


class aten_vstack(torch.nn.Module):
    def forward(self, x):
        return torch.vstack(self.prepare_input(x))

    def prepare_input(self, x):
        return (x, x)


class aten_vstack_out(aten_vstack):
    def forward(self, x, out):
        return torch.vstack(self.prepare_input(x), out=out), out


class TestVstack(PytorchLayerTest):
    def _prepare_input(self, out=False, num_repeats=2):
        data = np.random.randn(2, 1, 3)
        if not out:
            return (data, )
        concat = [data for _ in range(num_repeats)]
        out = np.zeros_like(np.concatenate(concat, axis=0))
        return (data, out)

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("out", [False, True])
    def test_vstack(self, out, ie_device, precision, ir_version):
        model = aten_vstack() if not out else aten_vstack_out()
        self._test(model, None, "aten::vstack", ie_device,
                   precision, ir_version, kwargs_to_prepare_input={"out": out, "num_repeats": 2})


class TestVstackAlignTypes(PytorchLayerTest):
    def _prepare_input(self, in_types):
        in_vals = []
        for i in range(len(in_types)):
            dtype = in_types[i]
            in_vals.append(np.random.randn(2, 1, 3).astype(dtype))
        return in_vals

    def create_model(self, in_count):
        class aten_align_types_vstack_two_args(torch.nn.Module):
            def forward(self, x, y):
                ins = [x, y]
                return torch.vstack(ins)

        class aten_align_types_vstack_three_args(torch.nn.Module):
            def forward(self, x, y, z):
                ins = [x, y, z]
                return torch.vstack(ins)

        if in_count == 2:
            return aten_align_types_vstack_two_args()

        if in_count == 3:
            return aten_align_types_vstack_three_args()

    @pytest.mark.parametrize(("in_types"), [
        (np.float32, np.int32),
        (np.int32, np.float32),
        (np.float16, np.float32),
        (np.int16, np.float16),
        (np.int32, np.int64),
        # Three inputs
        (np.float32, np.int32, np.int32),
        (np.float32, np.int32, np.float32),
        (np.int32, np.float32, np.int32),
        (np.float32, np.int32, np.int16),
        (np.int32, np.float32, np.int16),
        (np.int16, np.int32, np.int16),
        (np.float16, np.float32, np.float16),
        (np.float32, np.float16, np.float32),
        (np.float16, np.int32, np.int16),
        (np.int16, np.float16, np.int16)
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_align_types_vstack(self, ie_device, precision, ir_version, in_types):
        self._test(self.create_model(len(in_types)), None, "aten::vstack",
                   ie_device, precision, ir_version, kwargs_to_prepare_input={"in_types": in_types})
