@@ -32,130 +32,153 @@ bool AsyncSchedulingNet::isInlineTask(int parent_id, int child_id) const {
3232 last_parent_op->device_option (), first_child_op->device_option ());
3333}
3434
35- void AsyncSchedulingNet::schedule (int task_id, bool run_inline) {
35+ // schedule() is not supposed to throw: all exceptions in the ops are caught
36+ // and reported at the end of the graph's execution; the full graph of tasks
37+ // is expected to be scheduled
38+ void AsyncSchedulingNet::schedule (int task_id, bool run_inline) noexcept {
3639 if (!testAndSetScheduled (task_id)) {
3740 return ;
3841 }
3942 auto schedule_func = [this , task_id]() {
40- if (success_) {
41- int stream_id = 0 ;
42- if (options_.streams_per_gpu_ > 1 ) {
43- stream_id = stream (task_id);
44- }
45- if (!run (task_id, stream_id)) {
46- success_ = false ;
43+ try {
44+ if (success_) {
45+ int stream_id = 0 ;
46+ if (options_.streams_per_gpu_ > 1 ) {
47+ try {
48+ stream_id = stream (task_id);
49+ } catch (const std::exception& e) {
50+ C10_LOG_EVERY_MS (ERROR, 1000 )
51+ << " Failed to select a stream: " << e.what ();
52+ }
53+ }
54+ if (!run (task_id, stream_id)) {
55+ success_ = false ;
56+ }
4757 }
48- }
4958
50- if (options_.report_stats_ ) {
51- auto last_op_id = lastTaskOpId (task_id);
52- auto * last_op = lastTaskOp (task_id);
53- if (last_op->device_option ().device_type () == PROTO_CPU &&
54- last_op->HasAsyncPart ()) {
55- last_op->event ().SetCallback (
56- [this , last_op_id] { counters_.AddPerOpAsyncEndTime (last_op_id); });
59+ if (options_.report_stats_ ) {
60+ try {
61+ auto last_op_id = lastTaskOpId (task_id);
62+ auto * last_op = lastTaskOp (task_id);
63+ if (last_op->device_option ().device_type () == PROTO_CPU &&
64+ last_op->HasAsyncPart ()) {
65+ last_op->event ().SetCallback ([this , last_op_id] {
66+ counters_.AddPerOpAsyncEndTime (last_op_id);
67+ });
68+ }
69+ } catch (const std::exception& e) {
70+ C10_LOG_EVERY_MS (ERROR, 1000 )
71+ << " Failed to report operator stats: " << e.what ();
72+ }
5773 }
58- }
5974
60- for (auto child_id : children (task_id)) {
61- int parent_count = updateParentCount (child_id);
62- if (parent_count == 0 ) {
63- // Schedule a child if:
64- // - there is failure, we skip an op execution and finish the job
65- // - forced scheduling though always_schedule_child_
66- // - finish_chain_ is set, in this case parents are
67- // guaranteed to be finished
68- // - in all other cases, check parents with canSchedule
69- if (!success_ || options_.always_schedule_child_ ||
70- options_.finish_chain_ || canSchedule (child_id)) {
71- // if DFS scheduling is enabled, run children inline,
72- // ignore DFS scheduling in callbacks
73- schedule (child_id, isInlineTask (task_id, child_id));
74- } else {
75- bool parent_failed = false ;
76- bool parent_needs_polling = false ;
77- std::vector<int > parents_with_callback;
75+ for (auto child_id : children (task_id)) {
76+ int parent_count = updateParentCount (child_id);
77+ if (parent_count == 0 ) {
78+ // Schedule a child if:
79+ // - there is failure, we skip an op execution and finish the job
78+ // - forced scheduling through always_schedule_child_
81+ // - finish_chain_ is set, in this case parents are
82+ // guaranteed to be finished
83+ // - in all other cases, check parents with canSchedule
84+ if (!success_ || options_.always_schedule_child_ ||
85+ options_.finish_chain_ || canSchedule (child_id)) {
86+ // if DFS scheduling is enabled, run children inline,
87+ // ignore DFS scheduling in callbacks
88+ schedule (child_id, isInlineTask (task_id, child_id));
89+ } else {
90+ bool parent_failed = false ;
91+ bool parent_needs_polling = false ;
92+ std::vector<int > parents_with_callback;
7893
79- for (auto parent_id : parents (child_id)) {
80- auto & parent_event = event (parent_id);
81- auto parent_status = parent_event.Query ();
94+ for (auto parent_id : parents (child_id)) {
95+ auto & parent_event = event (parent_id);
96+ auto parent_status = parent_event.Query ();
8297
83- if (parent_status == EventStatus::EVENT_FAILED) {
84- parent_failed = true ;
85- break ;
86- } else if (parent_status == EventStatus::EVENT_SCHEDULED) {
87- // parent is not finished yet, check if this is blocking us
88- // from scheduling a child
89- if (!canSchedule (parent_id, child_id)) {
90- // we can't schedule a child because of this parent,
91- // check if parent supports callback
92- if (parent_event.SupportsCallback ()) {
93- parents_with_callback.push_back (parent_id);
94- } else {
95- parent_needs_polling = true ;
96- break ;
98+ if (parent_status == EventStatus::EVENT_FAILED) {
99+ parent_failed = true ;
100+ break ;
101+ } else if (parent_status == EventStatus::EVENT_SCHEDULED) {
102+ // parent is not finished yet, check if this is blocking us
103+ // from scheduling a child
104+ if (!canSchedule (parent_id, child_id)) {
105+ // we can't schedule a child because of this parent,
106+ // check if parent supports callback
107+ if (parent_event.SupportsCallback ()) {
108+ parents_with_callback.push_back (parent_id);
109+ } else {
110+ parent_needs_polling = true ;
111+ break ;
112+ }
97113 }
114+ } else if (parent_status != EventStatus::EVENT_SUCCESS) {
115+ VLOG (1 ) << " Unexpected parent task state: " << parent_status
116+ << " , task id: " << child_id
117+ << " , parent task id: " << parent_id;
118+ parent_failed = true ;
119+ break ;
98120 }
99- } else if (parent_status != EventStatus::EVENT_SUCCESS) {
100- VLOG (1 ) << " Unexpected parent task state: " << parent_status
101- << " , task id: " << child_id
102- << " , parent task id: " << parent_id;
103- parent_failed = true ;
104- break ;
105121 }
106- }
107122
108- if (parent_failed) {
109- // one of parents failed, set failure flag and wrap up execution
110- success_ = false ;
111- schedule (child_id, isInlineTask (task_id, child_id));
112- } else if (parent_needs_polling) {
113- // some parents are blocking us from scheduling a child and don't
114- // support callbacks, using polling
115- const auto & child_device_option = event (child_id).GetDeviceOption ();
116- pool (child_device_option)
117- ->run (std::bind (
118- &AsyncSchedulingNet::pollAndSchedule, this , child_id));
119- } else if (!parents_with_callback.empty ()) {
120- // some parents are blocking us from scheduling a child and they
121- // support callbacks
122- for (auto parent_id : parents_with_callback) {
123- event (parent_id).SetCallback (std::bind (
124- &AsyncSchedulingNet::parentCallback, this , parent_id));
123+ if (parent_failed) {
124+ // one of parents failed, set failure flag and wrap up execution
125+ success_ = false ;
126+ schedule (child_id, isInlineTask (task_id, child_id));
127+ } else if (parent_needs_polling) {
128+ // some parents are blocking us from scheduling a child and don't
129+ // support callbacks, using polling
130+ const auto & child_device_option =
131+ event (child_id).GetDeviceOption ();
132+ pool (child_device_option)
133+ ->run (std::bind (
134+ &AsyncSchedulingNet::pollAndSchedule, this , child_id));
135+ } else if (!parents_with_callback.empty ()) {
136+ // some parents are blocking us from scheduling a child and they
137+ // support callbacks
138+ for (auto parent_id : parents_with_callback) {
139+ event (parent_id).SetCallback (std::bind (
140+ &AsyncSchedulingNet::parentCallback, this , parent_id));
141+ }
142+ } else {
143+ // we're ready to schedule a child
144+ schedule (child_id, isInlineTask (task_id, child_id));
125145 }
126- } else {
127- // we're ready to schedule a child
128- schedule (child_id, isInlineTask (task_id, child_id));
129146 }
130147 }
131148 }
132- }
133149
134- // In case of net's failure, make sure all pending tasks are finished
135- if (!success_) {
136- // Simple logic to capture all pending tasks - check all tasks
137- // at the end of each task in case of net's failure
138- for (auto tid = 0 ; tid < tasksNum (); ++tid) {
139- if (event (tid).Query () == EventStatus::EVENT_SCHEDULED) {
140- // SetFinished may throw, e.g. when we call it on already finished
141- // event, and in some other cases (CUDA)
142- try {
143- event (tid).SetFinished (" Cancelled" );
144- } catch (const EnforceNotMet&) {
145- // ignore
150+ // In case of net's failure, make sure all pending tasks are finished
151+ if (!success_) {
152+ // Simple logic to capture all pending tasks - check all tasks
153+ // at the end of each task in case of net's failure
154+ for (auto tid = 0 ; tid < tasksNum (); ++tid) {
155+ if (event (tid).Query () == EventStatus::EVENT_SCHEDULED) {
156+ // SetFinished may throw, e.g. when we call it on already finished
157+ // event, and in some other cases (CUDA)
158+ try {
159+ event (tid).SetFinished (" Cancelled" );
160+ } catch (const EnforceNotMet&) {
161+ // ignore
162+ }
146163 }
147164 }
148165 }
149- }
150166
151- // finishRun may cause waiters to wake up and destroy the net,
152- // before we call finishRun we need to make sure all other (finishing)
153- // tasks are done;
154- // Bumping and checking the counter after the task's job is done
155- auto tasks_num = tasksNum ();
156- auto cur_processed_tasks = ++processed_tasks_num_;
157- if (cur_processed_tasks == tasks_num) {
158- finishRun ();
167+ // finishRun may cause waiters to wake up and destroy the net,
168+ // before we call finishRun we need to make sure all other (finishing)
169+ // tasks are done;
170+ // Bumping and checking the counter after the task's job is done
171+ auto tasks_num = tasksNum ();
172+ auto cur_processed_tasks = ++processed_tasks_num_;
173+ if (cur_processed_tasks == tasks_num) {
174+ finishRun ();
175+ }
176+ } catch (const std::exception& e) {
177+ // error of core scheduling and/or logic, will call terminate
178+ LOG (FATAL) << " Unexpected error during graph scheduling run: "
179+ << e.what ();
180+ } catch (...) {
181+ LOG (FATAL) << " Unknown error during graph scheduling run" ;
159182 }
160183 };
161184
@@ -215,33 +238,33 @@ void AsyncSchedulingNet::finishRun() {
215238
216239bool AsyncSchedulingNet::RunAsync () {
217240 try {
218- {
219- std::unique_lock<std::mutex> lock (running_mutex_);
220- if (running_) {
221- LOG (ERROR) << " Detected concurrent runs" ;
222- return false ;
223- }
224- running_ = true ;
225- reset ();
226-
227- StartAllObservers ();
228- tracing::startIter (tracer_);
229- if (options_.report_stats_ ) {
230- counters_.ReportRunStart ();
231- }
241+ std::unique_lock<std::mutex> lock (running_mutex_);
242+ if (running_) {
243+ LOG (ERROR) << " Detected concurrent runs" ;
244+ return false ;
232245 }
246+ running_ = true ;
247+ reset ();
233248
234- for ( auto task_id = 0 ; task_id < tasksNum (); ++task_id) {
235- if ( parents (task_id). empty ()) {
236- schedule (task_id, options_.run_root_tasks_inline_ );
237- }
249+ StartAllObservers ();
250+ tracing::startIter (tracer_);
251+ if ( options_.report_stats_ ) {
252+ counters_. ReportRunStart ();
238253 }
239254 } catch (const std::exception& e) {
240255 LOG (ERROR) << " Exception while starting an async run: " << e.what ();
241256 finishRun ();
242257 return false ;
243258 }
244259
260+ // schedule() is not expected to throw; at this point all the initial tasks
261+ // will be scheduled and the full graph of tasks will be executed
262+ for (auto task_id = 0 ; task_id < tasksNum (); ++task_id) {
263+ if (parents (task_id).empty ()) {
264+ schedule (task_id, options_.run_root_tasks_inline_ );
265+ }
266+ }
267+
245268 if (tasksNum () == 0 ) {
246269 finishRun ();
247270 }
0 commit comments