@@ -53,11 +53,18 @@ int getCounterForNetName(const std::string& net_name) {
53
53
return counter;
54
54
}
55
55
56
- Tracer::Tracer (const NetBase* net, const std::string& net_name)
57
- : net_(net), filename_(net_name), iter_(0 ) {
56
+ Tracer::Tracer (
57
+ const NetBase* net,
58
+ const std::string& net_name,
59
+ TracingConfig config)
60
+ : net_(net),
61
+ filename_ (net_name),
62
+ iter_(0 ),
63
+ dumping_iter_(0 ),
64
+ config_(config) {
58
65
std::replace (filename_.begin (), filename_.end (), ' /' , ' _' );
59
- filename_ = FLAGS_caffe2_net_async_tracing_filepath + " /" + filename_ +
60
- + " _id_ " + caffe2::to_string (getCounterForNetName (net_name));
66
+ filename_ = this -> config (). filepath + " /" + filename_ + " _id_ " +
67
+ caffe2::to_string (getCounterForNetName (net_name));
61
68
timer_.Start ();
62
69
}
63
70
@@ -251,6 +258,10 @@ int Tracer::bumpIter() {
251
258
return iter_++;
252
259
}
253
260
261
+ int Tracer::bumpDumpingIter () {
262
+ return dumping_iter_++;
263
+ }
264
+
254
265
void Tracer::dumpTracingResultAndClearEvents (const std::string& file_suffix) {
255
266
if (events_.empty () || filename_.empty ()) {
256
267
return ;
@@ -294,7 +305,9 @@ void TracerGuard::addArgument(TracingField field, const char* value) {
294
305
event_.category_ = value;
295
306
break ;
296
307
}
297
- default : { CAFFE_THROW (" Unexpected tracing string field " , field); }
308
+ default : {
309
+ CAFFE_THROW (" Unexpected tracing string field " , field);
310
+ }
298
311
}
299
312
}
300
313
@@ -316,7 +329,9 @@ void TracerGuard::addArgument(TracingField field, int value) {
316
329
event_.thread_label_ = value;
317
330
break ;
318
331
}
319
- default : { CAFFE_THROW (" Unexpected tracing int field " , field); }
332
+ default : {
333
+ CAFFE_THROW (" Unexpected tracing int field " , field);
334
+ }
320
335
}
321
336
}
322
337
@@ -388,25 +403,66 @@ bool hasEnableTracingFlag(const NetBase* net) {
388
403
return GetFlagArgument (net->debug_def (), " enable_tracing" , false );
389
404
}
390
405
406
+ TracingConfig getTracingConfigFromNet (const NetBase* net) {
407
+ ArgumentHelper arg_helper (net->debug_def ());
408
+ TracingConfig cfg;
409
+
410
+ cfg.mode = (arg_helper.GetSingleArgument <std::string>(" tracing_mode" , " " ) ==
411
+ " GLOBAL_TIMESLICE" )
412
+ ? TracingMode::GLOBAL_TIMESLICE
413
+ : TracingMode::EVERY_K_ITERATIONS;
414
+
415
+ cfg.filepath = arg_helper.GetSingleArgument <std::string>(
416
+ " tracing_filepath" , FLAGS_caffe2_net_async_tracing_filepath);
417
+
418
+ cfg.trace_every_nth_batch = arg_helper.GetSingleArgument <int >(
419
+ " trace_every_nth_batch" , FLAGS_caffe2_net_async_tracing_nth);
420
+ cfg.dump_every_nth_batch = arg_helper.GetSingleArgument <int >(
421
+ " dump_every_nth_batch" , FLAGS_caffe2_net_async_tracing_dumping_nth);
422
+
423
+ cfg.trace_for_n_ms =
424
+ arg_helper.GetSingleArgument <int >(" trace_for_n_ms" , cfg.trace_for_n_ms );
425
+ cfg.trace_every_n_ms = arg_helper.GetSingleArgument <int >(
426
+ " trace_every_n_ms" , cfg.trace_every_n_ms );
427
+
428
+ return cfg;
429
+ };
430
+
391
431
std::shared_ptr<Tracer> create (
392
432
const NetBase* net,
393
433
const std::string& net_name) {
394
434
// Enable the tracer if the net has the "enable_tracing" argument set OR
395
435
// if the command line option includes the net name option in the list of
396
436
// tracable nets.
397
437
bool trace_net = hasEnableTracingFlag (net) || isTraceableNetName (net_name);
398
- return trace_net ? std::make_shared<Tracer>(net, net_name) : nullptr ;
438
+ return trace_net
439
+ ? std::make_shared<Tracer>(net, net_name, getTracingConfigFromNet (net))
440
+ : nullptr ;
399
441
}
400
442
401
443
bool startIter (const std::shared_ptr<Tracer>& tracer) {
402
444
if (!tracer) {
403
445
return false ;
404
446
}
405
447
auto iter = tracer->bumpIter ();
406
- auto is_enabled = iter % FLAGS_caffe2_net_async_tracing_nth == 0 ;
448
+ bool is_enabled;
449
+ bool should_dump;
450
+ if (tracer->config ().mode == TracingMode::EVERY_K_ITERATIONS) {
451
+ is_enabled = iter % tracer->config ().trace_every_nth_batch == 0 ;
452
+ should_dump = iter % tracer->config ().dump_every_nth_batch == 0 ;
453
+ } else {
454
+ using namespace std ::chrono;
455
+ auto ms =
456
+ duration_cast<milliseconds>(system_clock::now ().time_since_epoch ())
457
+ .count ();
458
+ is_enabled = (ms % tracer->config ().trace_every_n_ms ) <
459
+ tracer->config ().trace_for_n_ms ;
460
+ // dump just after disabled tracing
461
+ should_dump = tracer->isEnabled () && !is_enabled;
462
+ }
407
463
tracer->setEnabled (is_enabled);
408
- if (iter % FLAGS_caffe2_net_async_tracing_dumping_nth == 0 ) {
409
- int dumping_iter = iter / FLAGS_caffe2_net_async_tracing_dumping_nth ;
464
+ if (should_dump ) {
465
+ int dumping_iter = tracer-> bumpDumpingIter () ;
410
466
tracer->dumpTracingResultAndClearEvents (caffe2::to_string (dumping_iter));
411
467
}
412
468
return is_enabled;
0 commit comments