@@ -32,130 +32,153 @@ bool AsyncSchedulingNet::isInlineTask(int parent_id, int child_id) const {
32
32
last_parent_op->device_option (), first_child_op->device_option ());
33
33
}
34
34
35
- void AsyncSchedulingNet::schedule (int task_id, bool run_inline) {
35
+ // schedule() is not supposed to throw, all exceptions in the ops are caught
36
+ // and reported in the end of the graph's execution, the full graph of tasks
37
+ // is expected to be scheduled
38
+ void AsyncSchedulingNet::schedule (int task_id, bool run_inline) noexcept {
36
39
if (!testAndSetScheduled (task_id)) {
37
40
return ;
38
41
}
39
42
auto schedule_func = [this , task_id]() {
40
- if (success_) {
41
- int stream_id = 0 ;
42
- if (options_.streams_per_gpu_ > 1 ) {
43
- stream_id = stream (task_id);
44
- }
45
- if (!run (task_id, stream_id)) {
46
- success_ = false ;
43
+ try {
44
+ if (success_) {
45
+ int stream_id = 0 ;
46
+ if (options_.streams_per_gpu_ > 1 ) {
47
+ try {
48
+ stream_id = stream (task_id);
49
+ } catch (const std::exception & e) {
50
+ C10_LOG_EVERY_MS (ERROR, 1000 )
51
+ << " Failed to select a stream: " << e.what ();
52
+ }
53
+ }
54
+ if (!run (task_id, stream_id)) {
55
+ success_ = false ;
56
+ }
47
57
}
48
- }
49
58
50
- if (options_.report_stats_ ) {
51
- auto last_op_id = lastTaskOpId (task_id);
52
- auto * last_op = lastTaskOp (task_id);
53
- if (last_op->device_option ().device_type () == PROTO_CPU &&
54
- last_op->HasAsyncPart ()) {
55
- last_op->event ().SetCallback (
56
- [this , last_op_id] { counters_.AddPerOpAsyncEndTime (last_op_id); });
59
+ if (options_.report_stats_ ) {
60
+ try {
61
+ auto last_op_id = lastTaskOpId (task_id);
62
+ auto * last_op = lastTaskOp (task_id);
63
+ if (last_op->device_option ().device_type () == PROTO_CPU &&
64
+ last_op->HasAsyncPart ()) {
65
+ last_op->event ().SetCallback ([this , last_op_id] {
66
+ counters_.AddPerOpAsyncEndTime (last_op_id);
67
+ });
68
+ }
69
+ } catch (const std::exception & e) {
70
+ C10_LOG_EVERY_MS (ERROR, 1000 )
71
+ << " Failed to report operator stats: " << e.what ();
72
+ }
57
73
}
58
- }
59
74
60
- for (auto child_id : children (task_id)) {
61
- int parent_count = updateParentCount (child_id);
62
- if (parent_count == 0 ) {
63
- // Schedule a child if:
64
- // - there is failure, we skip an op execution and finish the job
65
- // - forced scheduling though always_schedule_child_
66
- // - finish_chain_ is set, in this case parents are
67
- // guaranteed to be finished
68
- // - in all other cases, check parents with canSchedule
69
- if (!success_ || options_.always_schedule_child_ ||
70
- options_.finish_chain_ || canSchedule (child_id)) {
71
- // if DFS scheduling is enabled, run children inline,
72
- // ignore DFS scheduling in callbacks
73
- schedule (child_id, isInlineTask (task_id, child_id));
74
- } else {
75
- bool parent_failed = false ;
76
- bool parent_needs_polling = false ;
77
- std::vector<int > parents_with_callback;
75
+ for (auto child_id : children (task_id)) {
76
+ int parent_count = updateParentCount (child_id);
77
+ if (parent_count == 0 ) {
78
+ // Schedule a child if:
79
+ // - there is failure, we skip an op execution and finish the job
80
+ // - forced scheduling though always_schedule_child_
81
+ // - finish_chain_ is set, in this case parents are
82
+ // guaranteed to be finished
83
+ // - in all other cases, check parents with canSchedule
84
+ if (!success_ || options_.always_schedule_child_ ||
85
+ options_.finish_chain_ || canSchedule (child_id)) {
86
+ // if DFS scheduling is enabled, run children inline,
87
+ // ignore DFS scheduling in callbacks
88
+ schedule (child_id, isInlineTask (task_id, child_id));
89
+ } else {
90
+ bool parent_failed = false ;
91
+ bool parent_needs_polling = false ;
92
+ std::vector<int > parents_with_callback;
78
93
79
- for (auto parent_id : parents (child_id)) {
80
- auto & parent_event = event (parent_id);
81
- auto parent_status = parent_event.Query ();
94
+ for (auto parent_id : parents (child_id)) {
95
+ auto & parent_event = event (parent_id);
96
+ auto parent_status = parent_event.Query ();
82
97
83
- if (parent_status == EventStatus::EVENT_FAILED) {
84
- parent_failed = true ;
85
- break ;
86
- } else if (parent_status == EventStatus::EVENT_SCHEDULED) {
87
- // parent is not finished yet, check if this is blocking us
88
- // from scheduling a child
89
- if (!canSchedule (parent_id, child_id)) {
90
- // we can't schedule a child because of this parent,
91
- // check if parent supports callback
92
- if (parent_event.SupportsCallback ()) {
93
- parents_with_callback.push_back (parent_id);
94
- } else {
95
- parent_needs_polling = true ;
96
- break ;
98
+ if (parent_status == EventStatus::EVENT_FAILED) {
99
+ parent_failed = true ;
100
+ break ;
101
+ } else if (parent_status == EventStatus::EVENT_SCHEDULED) {
102
+ // parent is not finished yet, check if this is blocking us
103
+ // from scheduling a child
104
+ if (!canSchedule (parent_id, child_id)) {
105
+ // we can't schedule a child because of this parent,
106
+ // check if parent supports callback
107
+ if (parent_event.SupportsCallback ()) {
108
+ parents_with_callback.push_back (parent_id);
109
+ } else {
110
+ parent_needs_polling = true ;
111
+ break ;
112
+ }
97
113
}
114
+ } else if (parent_status != EventStatus::EVENT_SUCCESS) {
115
+ VLOG (1 ) << " Unexpected parent task state: " << parent_status
116
+ << " , task id: " << child_id
117
+ << " , parent task id: " << parent_id;
118
+ parent_failed = true ;
119
+ break ;
98
120
}
99
- } else if (parent_status != EventStatus::EVENT_SUCCESS) {
100
- VLOG (1 ) << " Unexpected parent task state: " << parent_status
101
- << " , task id: " << child_id
102
- << " , parent task id: " << parent_id;
103
- parent_failed = true ;
104
- break ;
105
121
}
106
- }
107
122
108
- if (parent_failed) {
109
- // one of parents failed, set failure flag and wrap up execution
110
- success_ = false ;
111
- schedule (child_id, isInlineTask (task_id, child_id));
112
- } else if (parent_needs_polling) {
113
- // some parents are blocking us from scheduling a child and don't
114
- // support callbacks, using polling
115
- const auto & child_device_option = event (child_id).GetDeviceOption ();
116
- pool (child_device_option)
117
- ->run (std::bind (
118
- &AsyncSchedulingNet::pollAndSchedule, this , child_id));
119
- } else if (!parents_with_callback.empty ()) {
120
- // some parents are blocking us from scheduling a child and they
121
- // support callbacks
122
- for (auto parent_id : parents_with_callback) {
123
- event (parent_id).SetCallback (std::bind (
124
- &AsyncSchedulingNet::parentCallback, this , parent_id));
123
+ if (parent_failed) {
124
+ // one of parents failed, set failure flag and wrap up execution
125
+ success_ = false ;
126
+ schedule (child_id, isInlineTask (task_id, child_id));
127
+ } else if (parent_needs_polling) {
128
+ // some parents are blocking us from scheduling a child and don't
129
+ // support callbacks, using polling
130
+ const auto & child_device_option =
131
+ event (child_id).GetDeviceOption ();
132
+ pool (child_device_option)
133
+ ->run (std::bind (
134
+ &AsyncSchedulingNet::pollAndSchedule, this , child_id));
135
+ } else if (!parents_with_callback.empty ()) {
136
+ // some parents are blocking us from scheduling a child and they
137
+ // support callbacks
138
+ for (auto parent_id : parents_with_callback) {
139
+ event (parent_id).SetCallback (std::bind (
140
+ &AsyncSchedulingNet::parentCallback, this , parent_id));
141
+ }
142
+ } else {
143
+ // we're ready to schedule a child
144
+ schedule (child_id, isInlineTask (task_id, child_id));
125
145
}
126
- } else {
127
- // we're ready to schedule a child
128
- schedule (child_id, isInlineTask (task_id, child_id));
129
146
}
130
147
}
131
148
}
132
- }
133
149
134
- // In case of net's failure, make sure all pending tasks are finished
135
- if (!success_) {
136
- // Simple logic to capture all pending tasks - check all tasks
137
- // at the end of each task in case of net's failure
138
- for (auto tid = 0 ; tid < tasksNum (); ++tid) {
139
- if (event (tid).Query () == EventStatus::EVENT_SCHEDULED) {
140
- // SetFinished may throw, e.g. when we call it on already finished
141
- // event, and in some other cases (CUDA)
142
- try {
143
- event (tid).SetFinished (" Cancelled" );
144
- } catch (const EnforceNotMet&) {
145
- // ignore
150
+ // In case of net's failure, make sure all pending tasks are finished
151
+ if (!success_) {
152
+ // Simple logic to capture all pending tasks - check all tasks
153
+ // at the end of each task in case of net's failure
154
+ for (auto tid = 0 ; tid < tasksNum (); ++tid) {
155
+ if (event (tid).Query () == EventStatus::EVENT_SCHEDULED) {
156
+ // SetFinished may throw, e.g. when we call it on already finished
157
+ // event, and in some other cases (CUDA)
158
+ try {
159
+ event (tid).SetFinished (" Cancelled" );
160
+ } catch (const EnforceNotMet&) {
161
+ // ignore
162
+ }
146
163
}
147
164
}
148
165
}
149
- }
150
166
151
- // finishRun may cause waiters to wake up and destroy the net,
152
- // before we call finishRun we need to make sure all other (finishing)
153
- // tasks are done;
154
- // Bumping and checking the counter after the task's job is done
155
- auto tasks_num = tasksNum ();
156
- auto cur_processed_tasks = ++processed_tasks_num_;
157
- if (cur_processed_tasks == tasks_num) {
158
- finishRun ();
167
+ // finishRun may cause waiters to wake up and destroy the net,
168
+ // before we call finishRun we need to make sure all other (finishing)
169
+ // tasks are done;
170
+ // Bumping and checking the counter after the task's job is done
171
+ auto tasks_num = tasksNum ();
172
+ auto cur_processed_tasks = ++processed_tasks_num_;
173
+ if (cur_processed_tasks == tasks_num) {
174
+ finishRun ();
175
+ }
176
+ } catch (const std::exception & e) {
177
+ // error of core scheduling and/or logic, will call terminate
178
+ LOG (FATAL) << " Unexpected error during graph scheduling run: "
179
+ << e.what ();
180
+ } catch (...) {
181
+ LOG (FATAL) << " Unknown error during graph scheduling run" ;
159
182
}
160
183
};
161
184
@@ -215,33 +238,33 @@ void AsyncSchedulingNet::finishRun() {
215
238
216
239
bool AsyncSchedulingNet::RunAsync () {
217
240
try {
218
- {
219
- std::unique_lock<std::mutex> lock (running_mutex_);
220
- if (running_) {
221
- LOG (ERROR) << " Detected concurrent runs" ;
222
- return false ;
223
- }
224
- running_ = true ;
225
- reset ();
226
-
227
- StartAllObservers ();
228
- tracing::startIter (tracer_);
229
- if (options_.report_stats_ ) {
230
- counters_.ReportRunStart ();
231
- }
241
+ std::unique_lock<std::mutex> lock (running_mutex_);
242
+ if (running_) {
243
+ LOG (ERROR) << " Detected concurrent runs" ;
244
+ return false ;
232
245
}
246
+ running_ = true ;
247
+ reset ();
233
248
234
- for ( auto task_id = 0 ; task_id < tasksNum (); ++task_id) {
235
- if ( parents (task_id). empty ()) {
236
- schedule (task_id, options_.run_root_tasks_inline_ );
237
- }
249
+ StartAllObservers ();
250
+ tracing::startIter (tracer_);
251
+ if ( options_.report_stats_ ) {
252
+ counters_. ReportRunStart ();
238
253
}
239
254
} catch (const std::exception & e) {
240
255
LOG (ERROR) << " Exception while starting an async run: " << e.what ();
241
256
finishRun ();
242
257
return false ;
243
258
}
244
259
260
+ // schedule() is not expected to throw, at this moment all the initial tasks
261
+ // will be scheduled and the full graph of tasks will be executed
262
+ for (auto task_id = 0 ; task_id < tasksNum (); ++task_id) {
263
+ if (parents (task_id).empty ()) {
264
+ schedule (task_id, options_.run_root_tasks_inline_ );
265
+ }
266
+ }
267
+
245
268
if (tasksNum () == 0 ) {
246
269
finishRun ();
247
270
}
0 commit comments