Skip to content

Commit 7652c5d

Browse files
Jingchang Zhangfacebook-github-bot
Jingchang Zhang
authored andcommitted
Reduce unnecessary function call in _set_sharding_context_post_a2a (#2796)
Summary: Pull Request resolved: #2796 This diff aims to improve the performance of the `_set_sharding_context_post_a2a` in `embedding_sharding` module in TorchRec by reducing unnecessary function calls. From different benchmark results, it reduce the function call by over 50% as below: _set_sharding_context_post_a2a_previous-KJT_len:100-keys:10 | Runtime (P90): 0.45 ms _set_sharding_context_post_a2a-KJT_len:100-keys:10 | Runtime (P90): 0.28 ms _set_sharding_context_post_a2a_previous-KJT_len:100-keys:100 | Runtime (P90): 2.47 ms _set_sharding_context_post_a2a-KJT_len:100-keys:100 | Runtime (P90): 0.89 ms _set_sharding_context_post_a2a_previous-KJT_len:1000-keys:10 | Runtime (P90): 4.55 ms _set_sharding_context_post_a2a-KJT_len:1000-keys:10 | Runtime (P90): 2.73 ms _set_sharding_context_post_a2a_previous-KJT_len:1000-keys:100 | Runtime (P90): 24.49 ms _set_sharding_context_post_a2a-KJT_len:1000-keys:100 | Runtime (P90): 8.68 ms _set_sharding_context_post_a2a_previous-KJT_len:10000-keys:10 | Runtime (P90): 46.85 ms _set_sharding_context_post_a2a-KJT_len:10000-keys:10 | Runtime (P90): 28.00 ms _set_sharding_context_post_a2a_previous-KJT_len:10000-keys:100 | Runtime (P90): 243.57 ms _set_sharding_context_post_a2a-KJT_len:10000-keys:100 | Runtime (P90): 89.14 ms Previous trace: about 4-5ms {F1975959472} Now: less than 4ms, some even less than 0.2ms, with much less operations {F1975959630} {F1975959609} Test on IG TAB ESR TTSN model: NE trending is same: {F1975957602} {F1975957605} {F1975957635} Reviewed By: dstaay-fb Differential Revision: D70960056 Privacy Context Container: L1292699 fbshipit-source-id: d5678b6893dbdae63e1f9119e63314c70a808cc0
1 parent fb1f2c4 commit 7652c5d

File tree

4 files changed

+165
-10
lines changed

4 files changed

+165
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
#!/usr/bin/env python3
11+
12+
from typing import Any, List
13+
14+
import click
15+
import torch
16+
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func
17+
from torchrec.distributed.embedding import EmbeddingCollectionContext
18+
from torchrec.distributed.embedding_sharding import _set_sharding_context_post_a2a
19+
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
20+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
21+
22+
23+
def _set_sharding_context_post_a2a_previous(
24+
kjts: List[KeyedJaggedTensor],
25+
ctx: EmbeddingCollectionContext,
26+
) -> None:
27+
for kjt, sharding_context in zip(kjts, getattr(ctx, "sharding_contexts", [])):
28+
if (
29+
hasattr(sharding_context, "batch_size_per_rank_per_feature")
30+
and kjt.variable_stride_per_key()
31+
and kjt.stride_per_key_per_rank()
32+
):
33+
sharding_context.batch_size_per_rank_per_feature = [
34+
[
35+
kjt.stride_per_key_per_rank()[i][j]
36+
for i in range(len(kjt.stride_per_key_per_rank()))
37+
]
38+
for j in range(len(kjt.stride_per_key_per_rank()[0]))
39+
]
40+
41+
42+
# buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_set_sharding_context_post_a2a -- --num_list=0 --num_keys=0 | grep set_sharding_context_post_a2a
43+
44+
45+
@click.command()
46+
@click.option("--num_list", default=100)
47+
@click.option("--num_keys", default=100)
48+
def main(
49+
num_list: int,
50+
num_keys: int,
51+
) -> None:
52+
if num_list == 0 and num_keys == 0:
53+
for num_list in [100, 1000, 10000]:
54+
for num_keys in [10, 100]:
55+
op_bench(num_list, num_keys, _set_sharding_context_post_a2a_previous)
56+
op_bench(num_list, num_keys, _set_sharding_context_post_a2a)
57+
else:
58+
op_bench(num_list, num_keys, _set_sharding_context_post_a2a_previous)
59+
op_bench(num_list, num_keys, _set_sharding_context_post_a2a)
60+
61+
62+
def op_bench(
63+
num_list: int,
64+
num_keys: int,
65+
func_to_benchmark: Any, # pyre-ignore[2]
66+
) -> None:
67+
kjts = [
68+
KeyedJaggedTensor(
69+
keys=["dummy_id"] * num_keys,
70+
values=torch.IntTensor([1] * num_keys),
71+
lengths=torch.IntTensor([1] * num_keys),
72+
stride_per_key_per_rank=[[1]] * num_keys,
73+
)
74+
for _ in range(num_list)
75+
]
76+
for kjt in kjts:
77+
kjt._variable_stride_per_key = True
78+
ctx = EmbeddingCollectionContext(
79+
sharding_contexts=[
80+
SequenceShardingContext(batch_size_per_rank_per_feature=[])
81+
for _ in range(num_list)
82+
]
83+
)
84+
85+
bench_inputs = []
86+
87+
result = benchmark_func(
88+
name=f"{func_to_benchmark.__name__}-{num_list}-{num_keys}",
89+
bench_inputs=bench_inputs,
90+
prof_inputs=bench_inputs,
91+
num_benchmarks=10,
92+
num_profiles=2,
93+
profile_dir=".",
94+
world_size=1,
95+
func_to_benchmark=func_to_benchmark,
96+
benchmark_func_kwargs={"kjts": kjts, "ctx": ctx},
97+
rank=0,
98+
pre_gpu_load=0,
99+
device_type="cpu",
100+
)
101+
print(result)
102+
103+
104+
if __name__ == "__main__":
105+
main()

torchrec/distributed/benchmark/benchmark_utils.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class BenchmarkResult:
135135

136136
def __str__(self) -> str:
137137
runtime = f"Runtime (P90): {self.runtime_percentile(90):.2f} ms"
138+
if len(self.mem_stats) == 0:
139+
return f"{self.short_name: <{35}} | {runtime}"
138140
mem_alloc = (
139141
f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB"
140142
)
@@ -749,11 +751,18 @@ def benchmark_func(
749751
func_to_benchmark(bench_inputs, **benchmark_func_kwargs)
750752
end[i].record()
751753
elif device_type == "cpu":
752-
times = timeit.repeat(
753-
lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
754-
number=1,
755-
repeat=num_benchmarks,
756-
)
754+
if bench_inputs is None or len(bench_inputs) == 0:
755+
times = timeit.repeat(
756+
lambda: func_to_benchmark(**benchmark_func_kwargs),
757+
number=1,
758+
repeat=num_benchmarks,
759+
)
760+
else:
761+
times = timeit.repeat(
762+
lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
763+
number=1,
764+
repeat=num_benchmarks,
765+
)
757766

758767
if device_type == "cuda":
759768
if rank == -1:

torchrec/distributed/embedding_sharding.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -662,12 +662,10 @@ def _set_sharding_context_post_a2a(
662662
and kjt.variable_stride_per_key()
663663
and kjt.stride_per_key_per_rank()
664664
):
665+
strides = kjt.stride_per_key_per_rank()
665666
sharding_context.batch_size_per_rank_per_feature = [
666-
[
667-
kjt.stride_per_key_per_rank()[i][j]
668-
for i in range(len(kjt.stride_per_key_per_rank()))
669-
]
670-
for j in range(len(kjt.stride_per_key_per_rank()[0]))
667+
[strides[i][j] for i in range(len(strides))]
668+
for j in range(len(strides[0]))
671669
]
672670

673671

torchrec/distributed/tests/test_embedding_sharding.py

+43
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
from unittest.mock import MagicMock
1515

1616
import hypothesis.strategies as st
17+
import torch
1718

1819
from hypothesis import given, settings
20+
from torchrec.distributed.embedding import EmbeddingCollectionContext
1921

2022
from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel
2123

@@ -24,14 +26,17 @@
2426
_get_grouping_fused_params,
2527
_get_weighted_avg_cache_load_factor,
2628
_prefetch_and_cached,
29+
_set_sharding_context_post_a2a,
2730
group_tables,
2831
)
2932

3033
from torchrec.distributed.embedding_types import (
3134
GroupedEmbeddingConfig,
3235
ShardedEmbeddingTable,
3336
)
37+
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
3438
from torchrec.modules.embedding_configs import DataType, PoolingType
39+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
3540

3641

3742
class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase):
@@ -489,3 +494,41 @@ def test_use_one_tbe_per_table(
489494
_get_table_names_by_groups(tables),
490495
[["table_0", "table_2", "table_4"], ["table_1", "table_1"], ["table_3"]],
491496
)
497+
498+
def test_set_sharding_context_post_a2a(self) -> None:
499+
kjts = [
500+
KeyedJaggedTensor(
501+
keys=["dummy_id", "video_id", "owner_id", "xray_concepts", "dummy_id2"],
502+
values=torch.IntTensor([1] * 10),
503+
lengths=torch.IntTensor([1] * 10),
504+
stride_per_key_per_rank=[
505+
[1, 2],
506+
[1, 2],
507+
[2, 3],
508+
[5, 7],
509+
[3, 4],
510+
],
511+
),
512+
KeyedJaggedTensor(
513+
keys=["dummy_id", "video_id", "owner_id", "xray_concepts", "dummy_id2"],
514+
values=torch.IntTensor([1] * 10),
515+
lengths=torch.IntTensor([1] * 10),
516+
stride_per_key_per_rank=[[3, 1], [5, 2], [7, 3], [1, 2], [6, 8]],
517+
),
518+
]
519+
for kjt in kjts:
520+
kjt._variable_stride_per_key = True
521+
522+
ctx = EmbeddingCollectionContext(
523+
sharding_contexts=[
524+
SequenceShardingContext(batch_size_per_rank_per_feature=[]),
525+
SequenceShardingContext(batch_size_per_rank_per_feature=[]),
526+
]
527+
)
528+
results = [
529+
[[1, 1, 2, 5, 3], [2, 2, 3, 7, 4]],
530+
[[3, 5, 7, 1, 6], [1, 2, 3, 2, 8]],
531+
]
532+
_set_sharding_context_post_a2a(kjts, ctx)
533+
for context, result in zip(ctx.sharding_contexts, results):
534+
self.assertEqual(context.batch_size_per_rank_per_feature, result)

0 commit comments

Comments
 (0)