Skip to content

Commit e7653c7

Browse files
Yinghai Lufacebook-github-bot
Yinghai Lu
authored andcommitted
New chaining/partitioning algorithm for async_scheduling for inference (pytorch#11957)
Summary: Pull Request resolved: pytorch#11957 For distributed inference, we want to use the async_scheduling net to run the net as we need its async part. However, according to the profiling, async_net has a big overhead of dispatching tasks onto worker threads. This diff improves the issue by generating a smaller number of chains/tasks by grouping the sync ops that can be run in one shot. Note that it also schedules each individual async op as its own single-op chain because, unlike gpu ops, rpc ops are not guaranteed to be linearized at the remote site. For example, if you have two rpc ops `op1->op2`, op2 won't implicitly block until op1 finishes. Therefore we need to put each of the async ops in its own chain, as the async_scheduling net will only sync the tail of each chain. For all-sync-op nets, this change gives us `1.5X` slower than simple_net, while without the change, it is `7X` slower. The next step is to work on the executor to make the task scheduling faster, and to add a fallback path to be able to run ops inline if it's an all-sync net. Reviewed By: ilia-cher Differential Revision: D9874140 fbshipit-source-id: fcd45328698c29211f2c06ee3287194acda12227
1 parent f1f521f commit e7653c7

File tree

4 files changed

+399
-4
lines changed

4 files changed

+399
-4
lines changed

caffe2/core/net_async_base.cc

+10-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ C10_DEFINE_int(
1313

1414
C10_DECLARE_bool(caffe2_dag_net_collect_stats);
1515

16+
C10_DEFINE_bool(
17+
caffe2_net_async_inference_mode,
18+
false,
19+
"If set, use one single chain containing all ops");
20+
1621
C10_DEFINE_bool(
1722
caffe2_net_async_finish_chain,
1823
false,
@@ -73,7 +78,11 @@ AsyncNetBase::AsyncNetBase(
7378
operators_.push_back(op_ptr);
7479
}
7580

76-
execution_chains_ = dag_utils::computeChains(operator_nodes_);
81+
if (c10::FLAGS_caffe2_net_async_inference_mode) {
82+
execution_chains_ = dag_utils::computeGroups(operator_nodes_);
83+
} else {
84+
execution_chains_ = dag_utils::computeChains(operator_nodes_);
85+
}
7786
chains_.reserve(execution_chains_.size());
7887
for (const auto& kv : execution_chains_) {
7988
chains_.push_back(kv.second);

caffe2/core/net_dag_utils.cc

+82
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,88 @@ ExecutionChains computeChains(std::vector<OperatorNode>& orig_nodes) {
278278
return chains;
279279
}
280280

281+
// Here chains are essentially groups, we used chain/group interchangeably
282+
ExecutionChains computeGroups(std::vector<OperatorNode>& orig_nodes) {
283+
const std::vector<OpGraphNode> nodes = pruneOpNodeGraph(orig_nodes);
284+
ExecutionChains chains;
285+
std::vector<int> sync_frontier;
286+
std::vector<int> async_frontier;
287+
288+
std::vector<int> in_degrees;
289+
in_degrees.reserve(nodes.size());
290+
std::transform(
291+
nodes.begin(),
292+
nodes.end(),
293+
std::back_inserter(in_degrees),
294+
[](const OpGraphNode& n) { return n.parents_.size(); });
295+
296+
// Screen out the primary root nodes
297+
for (int idx = 0; idx < (int)nodes.size(); ++idx) {
298+
if (in_degrees[idx] == 0) {
299+
if (orig_nodes[idx].operator_->HasAsyncPart()) {
300+
async_frontier.push_back(idx);
301+
} else {
302+
sync_frontier.push_back(idx);
303+
}
304+
}
305+
}
306+
307+
// We check sync ops on the frontier first and then async ops. This gives us a
308+
// head start to execute sync ops locally while waiting for async ops to
309+
// finish.
310+
std::queue<int> q;
311+
while (!(async_frontier.empty() && sync_frontier.empty())) {
312+
// Sync ops
313+
for (const auto i : sync_frontier) {
314+
q.push(i);
315+
}
316+
sync_frontier.clear();
317+
std::vector<int> chain;
318+
while (!q.empty()) {
319+
int idx = q.front();
320+
q.pop();
321+
chain.push_back(idx);
322+
for (int child : nodes[idx].children_) {
323+
if (--in_degrees[child] == 0) {
324+
if (orig_nodes[child].operator_->HasAsyncPart()) {
325+
async_frontier.push_back(child);
326+
} else {
327+
q.push(child);
328+
}
329+
}
330+
}
331+
}
332+
// add the whole group of consecutive sync ops into one chain
333+
if (!chain.empty()) {
334+
chains.emplace(chain.front(), chain);
335+
}
336+
337+
// Async ops
338+
for (const auto i : async_frontier) {
339+
q.push(i);
340+
}
341+
async_frontier.clear();
342+
while (!q.empty()) {
343+
int idx = q.front();
344+
q.pop();
345+
// Put each individual node as a new chain
346+
chains[idx] = {idx};
347+
for (int child : nodes[idx].children_) {
348+
if (--in_degrees[child] == 0) {
349+
if (orig_nodes[child].operator_->HasAsyncPart()) {
350+
q.push(child);
351+
} else {
352+
sync_frontier.push_back(child);
353+
}
354+
}
355+
}
356+
}
357+
}
358+
359+
updateOperatorNodes(orig_nodes, chains);
360+
return chains;
361+
}
362+
281363
ExecutionChains singleChains(std::vector<OperatorNode>& nodes) {
282364
ExecutionChains chains;
283365
for (int i = 0; i < (int)nodes.size(); ++i) {

caffe2/core/net_dag_utils.h

+11-3
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,19 @@ struct OpGraphNode {
4343

4444
using ExecutionChains = std::unordered_map<int, std::vector<int>>;
4545

46-
ExecutionChains computeChains(std::vector<OperatorNode>& orig_nodes);
46+
C10_EXPORT ExecutionChains computeChains(std::vector<OperatorNode>& orig_nodes);
4747

48-
ExecutionChains singleChains(std::vector<OperatorNode>& nodes);
48+
// Instead of breaking down the DAG into chains, we partition it into clusters
49+
// of sync ops and individual async ops. This is useful for the distributed inference
50+
// case where we have sync and async cpu ops. Note that we have to sync each
51+
// async op individually instead of putting them into a chain and syncing its tail like a GPU op,
52+
// because CPU async ops are typically rpc calls and are not guaranteed to be
53+
// linearized at the remote site.
54+
C10_EXPORT ExecutionChains computeGroups(std::vector<OperatorNode>& orig_nodes);
4955

50-
std::vector<OperatorNode> prepareOperatorNodes(
56+
C10_EXPORT ExecutionChains singleChains(std::vector<OperatorNode>& nodes);
57+
58+
C10_EXPORT std::vector<OperatorNode> prepareOperatorNodes(
5159
const std::shared_ptr<const NetDef>& net_def,
5260
Workspace* ws);
5361

0 commit comments

Comments
 (0)