Skip to content

Commit 0337494

Browse files
Ilia Cherniavskii, facebook-github-bot
Ilia Cherniavskii
authored and committed
Reinforce scheduling invariants (pytorch#17132)
Summary: Pull Request resolved: pytorch#17132. The schedule() function is not supposed to throw exceptions and is expected to succeed in scheduling the full graph of tasks; potential errors (e.g. errors from the underlying thread pool, out-of-memory exceptions, etc.) are considered unrecoverable. The invariant: the graph of tasks is either not executed at all or executed in full before the call to finishRun(). Reviewed By: andrewwdye. Differential Revision: D14092457. fbshipit-source-id: a3e5d65dfee5ff5e5e71ec72bb9e576180019698
1 parent 3e44880 commit 0337494

File tree

4 files changed

+147
-124
lines changed

4 files changed

+147
-124
lines changed

caffe2/core/net_async_base.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ void AsyncNetBase::handleChainError(
369369
int task_id,
370370
OperatorBase* op,
371371
const char* err_str,
372-
bool save_exception) {
372+
bool save_exception) noexcept {
373373
std::string err_msg = err_str;
374374
if (op) {
375375
err_msg += ", op " + (op->has_debug_def() ? op->type() : " unknown");
@@ -385,7 +385,7 @@ void AsyncNetBase::handleChainError(
385385
}
386386
}
387387

388-
bool AsyncNetBase::run(int task_id, int stream_id) {
388+
bool AsyncNetBase::run(int task_id, int stream_id) noexcept {
389389
OperatorBase* op = nullptr;
390390
try {
391391
// Optionally insert async wait ops,

caffe2/core/net_async_base.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class CAFFE2_API AsyncNetBase : public NetBase {
106106
int task_id,
107107
int stream_id,
108108
const std::vector<int>& wait_task_ids) const;
109-
bool run(int task_id, int stream_id);
109+
bool run(int task_id, int stream_id) noexcept;
110110
int stream(int task_id);
111111
TaskThreadPoolBase* pool(const DeviceOption& device_option);
112112
TaskThreadPoolBase* pool();
@@ -144,7 +144,7 @@ class CAFFE2_API AsyncNetBase : public NetBase {
144144
int task_id,
145145
OperatorBase* op,
146146
const char* err_msg,
147-
bool save_exception = false);
147+
bool save_exception = false) noexcept;
148148
std::atomic<bool> success_;
149149

150150
// Tracing

caffe2/core/net_async_scheduling.cc

+142-119
Original file line numberDiff line numberDiff line change
@@ -32,130 +32,153 @@ bool AsyncSchedulingNet::isInlineTask(int parent_id, int child_id) const {
3232
last_parent_op->device_option(), first_child_op->device_option());
3333
}
3434

35-
void AsyncSchedulingNet::schedule(int task_id, bool run_inline) {
35+
// schedule() is not supposed to throw, all exceptions in the ops are caught
36+
// and reported in the end of the graph's execution, the full graph of tasks
37+
// is expected to be scheduled
38+
void AsyncSchedulingNet::schedule(int task_id, bool run_inline) noexcept {
3639
if (!testAndSetScheduled(task_id)) {
3740
return;
3841
}
3942
auto schedule_func = [this, task_id]() {
40-
if (success_) {
41-
int stream_id = 0;
42-
if (options_.streams_per_gpu_ > 1) {
43-
stream_id = stream(task_id);
44-
}
45-
if (!run(task_id, stream_id)) {
46-
success_ = false;
43+
try {
44+
if (success_) {
45+
int stream_id = 0;
46+
if (options_.streams_per_gpu_ > 1) {
47+
try {
48+
stream_id = stream(task_id);
49+
} catch (const std::exception& e) {
50+
C10_LOG_EVERY_MS(ERROR, 1000)
51+
<< "Failed to select a stream: " << e.what();
52+
}
53+
}
54+
if (!run(task_id, stream_id)) {
55+
success_ = false;
56+
}
4757
}
48-
}
4958

50-
if (options_.report_stats_) {
51-
auto last_op_id = lastTaskOpId(task_id);
52-
auto* last_op = lastTaskOp(task_id);
53-
if (last_op->device_option().device_type() == PROTO_CPU &&
54-
last_op->HasAsyncPart()) {
55-
last_op->event().SetCallback(
56-
[this, last_op_id] { counters_.AddPerOpAsyncEndTime(last_op_id); });
59+
if (options_.report_stats_) {
60+
try {
61+
auto last_op_id = lastTaskOpId(task_id);
62+
auto* last_op = lastTaskOp(task_id);
63+
if (last_op->device_option().device_type() == PROTO_CPU &&
64+
last_op->HasAsyncPart()) {
65+
last_op->event().SetCallback([this, last_op_id] {
66+
counters_.AddPerOpAsyncEndTime(last_op_id);
67+
});
68+
}
69+
} catch (const std::exception& e) {
70+
C10_LOG_EVERY_MS(ERROR, 1000)
71+
<< "Failed to report operator stats: " << e.what();
72+
}
5773
}
58-
}
5974

60-
for (auto child_id : children(task_id)) {
61-
int parent_count = updateParentCount(child_id);
62-
if (parent_count == 0) {
63-
// Schedule a child if:
64-
// - there is failure, we skip an op execution and finish the job
65-
// - forced scheduling though always_schedule_child_
66-
// - finish_chain_ is set, in this case parents are
67-
// guaranteed to be finished
68-
// - in all other cases, check parents with canSchedule
69-
if (!success_ || options_.always_schedule_child_ ||
70-
options_.finish_chain_ || canSchedule(child_id)) {
71-
// if DFS scheduling is enabled, run children inline,
72-
// ignore DFS scheduling in callbacks
73-
schedule(child_id, isInlineTask(task_id, child_id));
74-
} else {
75-
bool parent_failed = false;
76-
bool parent_needs_polling = false;
77-
std::vector<int> parents_with_callback;
75+
for (auto child_id : children(task_id)) {
76+
int parent_count = updateParentCount(child_id);
77+
if (parent_count == 0) {
78+
// Schedule a child if:
79+
// - there is failure, we skip an op execution and finish the job
80+
// - forced scheduling though always_schedule_child_
81+
// - finish_chain_ is set, in this case parents are
82+
// guaranteed to be finished
83+
// - in all other cases, check parents with canSchedule
84+
if (!success_ || options_.always_schedule_child_ ||
85+
options_.finish_chain_ || canSchedule(child_id)) {
86+
// if DFS scheduling is enabled, run children inline,
87+
// ignore DFS scheduling in callbacks
88+
schedule(child_id, isInlineTask(task_id, child_id));
89+
} else {
90+
bool parent_failed = false;
91+
bool parent_needs_polling = false;
92+
std::vector<int> parents_with_callback;
7893

79-
for (auto parent_id : parents(child_id)) {
80-
auto& parent_event = event(parent_id);
81-
auto parent_status = parent_event.Query();
94+
for (auto parent_id : parents(child_id)) {
95+
auto& parent_event = event(parent_id);
96+
auto parent_status = parent_event.Query();
8297

83-
if (parent_status == EventStatus::EVENT_FAILED) {
84-
parent_failed = true;
85-
break;
86-
} else if (parent_status == EventStatus::EVENT_SCHEDULED) {
87-
// parent is not finished yet, check if this is blocking us
88-
// from scheduling a child
89-
if (!canSchedule(parent_id, child_id)) {
90-
// we can't schedule a child because of this parent,
91-
// check if parent supports callback
92-
if (parent_event.SupportsCallback()) {
93-
parents_with_callback.push_back(parent_id);
94-
} else {
95-
parent_needs_polling = true;
96-
break;
98+
if (parent_status == EventStatus::EVENT_FAILED) {
99+
parent_failed = true;
100+
break;
101+
} else if (parent_status == EventStatus::EVENT_SCHEDULED) {
102+
// parent is not finished yet, check if this is blocking us
103+
// from scheduling a child
104+
if (!canSchedule(parent_id, child_id)) {
105+
// we can't schedule a child because of this parent,
106+
// check if parent supports callback
107+
if (parent_event.SupportsCallback()) {
108+
parents_with_callback.push_back(parent_id);
109+
} else {
110+
parent_needs_polling = true;
111+
break;
112+
}
97113
}
114+
} else if (parent_status != EventStatus::EVENT_SUCCESS) {
115+
VLOG(1) << "Unexpected parent task state: " << parent_status
116+
<< ", task id: " << child_id
117+
<< ", parent task id: " << parent_id;
118+
parent_failed = true;
119+
break;
98120
}
99-
} else if (parent_status != EventStatus::EVENT_SUCCESS) {
100-
VLOG(1) << "Unexpected parent task state: " << parent_status
101-
<< ", task id: " << child_id
102-
<< ", parent task id: " << parent_id;
103-
parent_failed = true;
104-
break;
105121
}
106-
}
107122

108-
if (parent_failed) {
109-
// one of parents failed, set failure flag and wrap up execution
110-
success_ = false;
111-
schedule(child_id, isInlineTask(task_id, child_id));
112-
} else if (parent_needs_polling) {
113-
// some parents are blocking us from scheduling a child and don't
114-
// support callbacks, using polling
115-
const auto& child_device_option = event(child_id).GetDeviceOption();
116-
pool(child_device_option)
117-
->run(std::bind(
118-
&AsyncSchedulingNet::pollAndSchedule, this, child_id));
119-
} else if (!parents_with_callback.empty()) {
120-
// some parents are blocking us from scheduling a child and they
121-
// support callbacks
122-
for (auto parent_id : parents_with_callback) {
123-
event(parent_id).SetCallback(std::bind(
124-
&AsyncSchedulingNet::parentCallback, this, parent_id));
123+
if (parent_failed) {
124+
// one of parents failed, set failure flag and wrap up execution
125+
success_ = false;
126+
schedule(child_id, isInlineTask(task_id, child_id));
127+
} else if (parent_needs_polling) {
128+
// some parents are blocking us from scheduling a child and don't
129+
// support callbacks, using polling
130+
const auto& child_device_option =
131+
event(child_id).GetDeviceOption();
132+
pool(child_device_option)
133+
->run(std::bind(
134+
&AsyncSchedulingNet::pollAndSchedule, this, child_id));
135+
} else if (!parents_with_callback.empty()) {
136+
// some parents are blocking us from scheduling a child and they
137+
// support callbacks
138+
for (auto parent_id : parents_with_callback) {
139+
event(parent_id).SetCallback(std::bind(
140+
&AsyncSchedulingNet::parentCallback, this, parent_id));
141+
}
142+
} else {
143+
// we're ready to schedule a child
144+
schedule(child_id, isInlineTask(task_id, child_id));
125145
}
126-
} else {
127-
// we're ready to schedule a child
128-
schedule(child_id, isInlineTask(task_id, child_id));
129146
}
130147
}
131148
}
132-
}
133149

134-
// In case of net's failure, make sure all pending tasks are finished
135-
if (!success_) {
136-
// Simple logic to capture all pending tasks - check all tasks
137-
// at the end of each task in case of net's failure
138-
for (auto tid = 0; tid < tasksNum(); ++tid) {
139-
if (event(tid).Query() == EventStatus::EVENT_SCHEDULED) {
140-
// SetFinished may throw, e.g. when we call it on already finished
141-
// event, and in some other cases (CUDA)
142-
try {
143-
event(tid).SetFinished("Cancelled");
144-
} catch (const EnforceNotMet&) {
145-
// ignore
150+
// In case of net's failure, make sure all pending tasks are finished
151+
if (!success_) {
152+
// Simple logic to capture all pending tasks - check all tasks
153+
// at the end of each task in case of net's failure
154+
for (auto tid = 0; tid < tasksNum(); ++tid) {
155+
if (event(tid).Query() == EventStatus::EVENT_SCHEDULED) {
156+
// SetFinished may throw, e.g. when we call it on already finished
157+
// event, and in some other cases (CUDA)
158+
try {
159+
event(tid).SetFinished("Cancelled");
160+
} catch (const EnforceNotMet&) {
161+
// ignore
162+
}
146163
}
147164
}
148165
}
149-
}
150166

151-
// finishRun may cause waiters to wake up and destroy the net,
152-
// before we call finishRun we need to make sure all other (finishing)
153-
// tasks are done;
154-
// Bumping and checking the counter after the task's job is done
155-
auto tasks_num = tasksNum();
156-
auto cur_processed_tasks = ++processed_tasks_num_;
157-
if (cur_processed_tasks == tasks_num) {
158-
finishRun();
167+
// finishRun may cause waiters to wake up and destroy the net,
168+
// before we call finishRun we need to make sure all other (finishing)
169+
// tasks are done;
170+
// Bumping and checking the counter after the task's job is done
171+
auto tasks_num = tasksNum();
172+
auto cur_processed_tasks = ++processed_tasks_num_;
173+
if (cur_processed_tasks == tasks_num) {
174+
finishRun();
175+
}
176+
} catch (const std::exception& e) {
177+
// error of core scheduling and/or logic, will call terminate
178+
LOG(FATAL) << "Unexpected error during graph scheduling run: "
179+
<< e.what();
180+
} catch (...) {
181+
LOG(FATAL) << "Unknown error during graph scheduling run";
159182
}
160183
};
161184

@@ -215,33 +238,33 @@ void AsyncSchedulingNet::finishRun() {
215238

216239
bool AsyncSchedulingNet::RunAsync() {
217240
try {
218-
{
219-
std::unique_lock<std::mutex> lock(running_mutex_);
220-
if (running_) {
221-
LOG(ERROR) << "Detected concurrent runs";
222-
return false;
223-
}
224-
running_ = true;
225-
reset();
226-
227-
StartAllObservers();
228-
tracing::startIter(tracer_);
229-
if (options_.report_stats_) {
230-
counters_.ReportRunStart();
231-
}
241+
std::unique_lock<std::mutex> lock(running_mutex_);
242+
if (running_) {
243+
LOG(ERROR) << "Detected concurrent runs";
244+
return false;
232245
}
246+
running_ = true;
247+
reset();
233248

234-
for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
235-
if (parents(task_id).empty()) {
236-
schedule(task_id, options_.run_root_tasks_inline_);
237-
}
249+
StartAllObservers();
250+
tracing::startIter(tracer_);
251+
if (options_.report_stats_) {
252+
counters_.ReportRunStart();
238253
}
239254
} catch (const std::exception& e) {
240255
LOG(ERROR) << "Exception while starting an async run: " << e.what();
241256
finishRun();
242257
return false;
243258
}
244259

260+
// schedule() is not expected to throw, at this moment all the initial tasks
261+
// will be scheduled and the full graph of tasks will be executed
262+
for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
263+
if (parents(task_id).empty()) {
264+
schedule(task_id, options_.run_root_tasks_inline_);
265+
}
266+
}
267+
245268
if (tasksNum() == 0) {
246269
finishRun();
247270
}

caffe2/core/net_async_scheduling.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class CAFFE2_API AsyncSchedulingNet : public AsyncNetBase {
1818
bool RunAsync() override;
1919

2020
void pollAndSchedule(int task_id);
21-
void schedule(int task_id, bool run_inline = false);
21+
void schedule(int task_id, bool run_inline = false) noexcept;
2222
void reset() override;
2323
virtual void finishRun();
2424
void parentCallback(int parent_id);

0 commit comments

Comments
 (0)