Commit 325c83e
fix: index_put converter to handle multi-shape slicing with None (#3475)
1 parent 291b833 commit 325c83e
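
For context, a minimal eager-mode sketch of the pattern this commit fixes (shapes here are illustrative, not taken from the converter): in `aten.index_put`, a `None` entry in the indices list marks a free, un-indexed dimension, and `values` must broadcast against the gathered shape.

```python
import torch

# None leaves a dimension un-indexed (a full slice); the index tensor picks
# columns 0 and 2, and `values` broadcasts across the free row dimension.
x = torch.zeros(3, 4)
idx = torch.tensor([0, 2])
values = torch.tensor([1.0, 2.0])  # broadcasts to the gathered shape [3, 2]
out = torch.ops.aten.index_put(x, [None, idx], values)
# Eager equivalent: x[:, idx] = values
```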

File tree

2 files changed: +79 -44 lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

+73 -44
```diff
@@ -502,7 +502,6 @@ def index_put_converter(
     F = [i for i in range(rank) if indices[i] is None]  # Free dimensions
     I = [i for i in range(rank) if indices[i] is not None]  # Indexed dimensions
     K = len(I)
-
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
         index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
```
```diff
@@ -685,16 +684,6 @@ def index_put_converter(
         values_reshaped = impl.shuffle.reshape(
             ctx, target, source_ir, f"{name}_reshape_scalar", values, (1,)
         )
-        num_dims = len(expected_shape)
-        ones_shape = tuple([1] * num_dims)
-        values_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_to_ones",
-            values_reshaped,
-            ones_shape,
-        )
         values_expanded = impl.slice.expand(
             ctx,
             target,
```
```diff
@@ -705,40 +694,79 @@ def index_put_converter(
         )
     else:  # Non-scalar case
         values_shape = list(values.shape)
-
-        # Pad dimensions if necessary
-        if len(values_shape) < len(expected_shape):
-            values_shape = [1] * (
-                len(expected_shape) - len(values_shape)
-            ) + values_shape
-
-        # Calculate a broadcastable shape
-        broadcast_shape = []
-        for exp_dim, val_dim in zip(expected_shape, values_shape):
-            if val_dim == 1:
-                broadcast_shape.append(exp_dim)
-            elif val_dim == exp_dim:
-                broadcast_shape.append(val_dim)
+        if K > 0 and N in values_shape:
+            n_idx = values_shape.index(N)
+            permute_order = [n_idx] + [
+                i for i in range(len(values_shape)) if i != n_idx
+            ]
+            values_permuted = impl.permutation.permute(
+                ctx, target, source_ir, f"{name}_permute_values", values, permute_order
+            )
+            remaining_shape = [
+                values_shape[i] for i in range(len(values_shape)) if i != n_idx
+            ]
+            target_f_dims = len(F)
+            current_f_dims = len(remaining_shape)
+            if current_f_dims < target_f_dims:
+                values_expanded_shape = (
+                    [N] + [1] * (target_f_dims - current_f_dims) + remaining_shape
+                )
             else:
-                raise ValueError(f"Cannot broadcast {values_shape} to {expected_shape}")
-
-        # Reshape and then expand
-        values_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_values",
-            values,
-            tuple(broadcast_shape),
-        )
-        values_expanded = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_values",
-            values_reshaped,
-            expected_shape,
-        )
+                values_expanded_shape = [N] + remaining_shape[:target_f_dims]
+            values_expanded = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_unsqueeze_values",
+                values_permuted,
+                tuple(values_expanded_shape),
+            )
+            broadcast_shape = []
+            for exp_dim, val_dim in zip(expected_shape, values_expanded_shape):
+                if val_dim == 1:
+                    broadcast_shape.append(exp_dim)
+                elif val_dim == exp_dim:
+                    broadcast_shape.append(val_dim)
+                else:
+                    raise ValueError(
+                        f"Cannot broadcast {values_expanded_shape} to {expected_shape}"
+                    )
+            values_expanded = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_values",
+                values_expanded,
+                tuple(broadcast_shape),
+            )
+        else:
+            values_shape_padded = [1] * (
+                len(expected_shape) - len(values.shape)
+            ) + list(values.shape)
+            broadcast_shape = []
+            for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
+                if val_dim == 1 or exp_dim == val_dim:
+                    broadcast_shape.append(exp_dim)
+                else:
+                    raise ValueError(
+                        f"Cannot broadcast {values.shape} to {expected_shape}"
+                    )
+            values_reshaped = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_values",
+                values,
+                tuple(broadcast_shape),
+            )
+            values_expanded = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_values",
+                values_reshaped,
+                expected_shape,
+            )
 
     # Flatten values to (N * F_volume,)
     flattened_values = impl.shuffle.reshape(
```
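
A framework-free sketch of the shape planning the new `K > 0 and N in values_shape` branch performs, using the shapes from the test case added below. It assumes `expected_shape` is the gather-output shape `(N, *free_dim_sizes)` as in the surrounding converter code; the function name is hypothetical, and the TensorRT permute/reshape/expand calls are replaced by plain list arithmetic.

```python
def plan_values_expansion(values_shape, N, num_free_dims, expected_shape):
    # Move the first axis whose size equals N (the index-tensor length) to
    # the front, mirroring the permute in the converter.
    n_idx = values_shape.index(N)
    permute_order = [n_idx] + [i for i in range(len(values_shape)) if i != n_idx]
    remaining = [values_shape[i] for i in range(len(values_shape)) if i != n_idx]
    # Left-pad the free dimensions with 1s (or truncate) to match len(F).
    if len(remaining) < num_free_dims:
        expanded = [N] + [1] * (num_free_dims - len(remaining)) + remaining
    else:
        expanded = [N] + remaining[:num_free_dims]
    # Standard broadcast rule against the gather-output shape.
    broadcast_shape = []
    for exp_dim, val_dim in zip(expected_shape, expanded):
        if val_dim in (1, exp_dim):
            broadcast_shape.append(exp_dim)
        else:
            raise ValueError(f"Cannot broadcast {expanded} to {expected_shape}")
    return permute_order, expanded, broadcast_shape

# Test-case shapes: source [1, 2, 5, 3], indices (None, None, [0, 1, 2]),
# values [2, 3, 3]; N = 3 and the free dims are (0, 1, 3).
print(plan_values_expansion([2, 3, 3], N=3, num_free_dims=3,
                            expected_shape=(3, 1, 2, 3)))
# -> ([1, 0, 2], [3, 1, 2, 3], [3, 1, 2, 3])
```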
```diff
@@ -750,6 +778,7 @@ def index_put_converter(
         (N * F_volume,),
     )
 
+    indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
     # Perform Scatter ND operation
     scatter_layer = ctx.net.add_scatter(
         input_tensor,
```

tests/py/dynamo/conversion/test_index_put_aten.py

+6
```diff
@@ -194,6 +194,12 @@ class TestIndexPutConverter(DispatchTestCase):
                     dtype=torch.int32,
                 ),
             ),
+            param(
+                test_name="4d_indices_none_none_multiple_idx_broadcast_error",
+                source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
+                indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
+                value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
+            ),
             # param(
             #     test_name="2d_indices_accumulate_True",
             #     source_tensor=torch.zeros([5, 5], dtype=torch.int32),
```
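
For reference, the new case corresponds to the eager-mode assignment below: the gathered region has shape `[1, 2, 3, 3]`, so the `[2, 3, 3]` values tensor must broadcast across the leading free dimension, which exercises the new permute/reshape path in the converter.

```python
import torch

# Eager-mode reference for the new test case: indices (None, None, idx)
# correspond to the advanced-indexing assignment below.
src = torch.zeros([1, 2, 5, 3], dtype=torch.float32)
idx = torch.tensor([0, 1, 2], dtype=torch.int64)
values = torch.randn([2, 3, 3], dtype=torch.float32)
src[:, :, idx] = values  # gathered shape [1, 2, 3, 3]; values broadcast over dim 0
```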
