@@ -130,6 +130,7 @@ class CacheEntry {
130
130
// The caller is expected to check `GlobalCallbackManager::get().version()'
131
131
// and call CacheEntry::update() if necessary.
132
132
StepCallbacks getActiveCallbacks ();
133
+ c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty ();
133
134
134
135
// Full rebuild. (E.g. during registration)
135
136
void update (const std::vector<RecordFunctionCallback>& callbacks);
@@ -142,6 +143,8 @@ class CacheEntry {
142
143
int tries_left_{-1 };
143
144
};
144
145
146
+ C10_ALWAYS_INLINE void getActiveCallbacksImpl ();
147
+
145
148
void rebuildActiveCallbacks ();
146
149
int sampleTries (double p) const ;
147
150
@@ -169,6 +172,7 @@ class LocalCallbackManager {
169
172
public:
170
173
const RecordFunctionTLS& getTLS () const ;
171
174
StepCallbacks getActiveCallbacks (const RecordScope scope);
175
+ c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty (const RecordScope scope);
172
176
173
177
void setTLS (const RecordFunctionTLS& tls);
174
178
void seed (uint32_t seed);
@@ -178,6 +182,8 @@ class LocalCallbackManager {
178
182
void clearCallbacks ();
179
183
180
184
private:
185
+ void rebuildActiveCallbacksIfNeeded ();
186
+
181
187
void rebuild_all (const GlobalCallbackManager::snapshot_t & global_snapshot);
182
188
183
189
void rebuild_callback_scopes (
@@ -271,7 +277,7 @@ void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) {
271
277
rebuildActiveCallbacks ();
272
278
}
273
279
274
- StepCallbacks CacheEntry::getActiveCallbacks () {
280
+ void CacheEntry::getActiveCallbacksImpl () {
275
281
// We rebuild the active set when `sampling_countdown_` reaches zero, so if it
276
282
// reaches zero at the start of this function something has gone wrong.
277
283
TORCH_INTERNAL_ASSERT (sampling_countdown_ > 0 , sampling_countdown_);
@@ -295,7 +301,18 @@ StepCallbacks CacheEntry::getActiveCallbacks() {
295
301
}
296
302
}
297
303
}
304
+ }
298
305
306
+ StepCallbacks CacheEntry::getActiveCallbacks () {
307
+ getActiveCallbacksImpl ();
308
+ return active_callbacks_;
309
+ }
310
+
311
+ c10::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty () {
312
+ getActiveCallbacksImpl ();
313
+ if (C10_LIKELY (active_callbacks_.empty ())) {
314
+ return c10::nullopt;
315
+ }
299
316
return active_callbacks_;
300
317
}
301
318
@@ -365,15 +382,25 @@ const RecordFunctionTLS& LocalCallbackManager::getTLS() const {
365
382
return registered_callbacks_;
366
383
}
367
384
368
- StepCallbacks LocalCallbackManager::getActiveCallbacks (
369
- const RecordScope scope) {
385
+ void LocalCallbackManager::rebuildActiveCallbacksIfNeeded () {
370
386
const auto global_version = GlobalCallbackManager::get ().version ();
371
387
if (C10_UNLIKELY (global_version != global_version_)) {
372
388
rebuild_all (GlobalCallbackManager::get ().getSnapshot ());
373
389
}
390
+ }
391
+
392
+ StepCallbacks LocalCallbackManager::getActiveCallbacks (
393
+ const RecordScope scope) {
394
+ rebuildActiveCallbacksIfNeeded ();
374
395
return active_callbacks_[static_cast <size_t >(scope)].getActiveCallbacks ();
375
396
}
376
397
398
+ c10::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty (
399
+ const RecordScope scope) {
400
+ rebuildActiveCallbacksIfNeeded ();
401
+ return active_callbacks_[static_cast <size_t >(scope)].getActiveCallbacksUnlessEmpty ();
402
+ }
403
+
377
404
void LocalCallbackManager::setTLS (const RecordFunctionTLS& tls) {
378
405
registered_callbacks_ = tls;
379
406
rebuild_all (GlobalCallbackManager::get ().getSnapshot ());
@@ -572,6 +599,10 @@ StepCallbacks getStepCallbacks(RecordScope scope) {
572
599
return LocalCallbackManager::get ().getActiveCallbacks (scope);
573
600
}
574
601
602
+ c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty (RecordScope scope) {
603
+ return LocalCallbackManager::get ().getActiveCallbacksUnlessEmpty (scope);
604
+ }
605
+
575
606
const RecordFunctionTLS& get_record_function_tls_ () {
576
607
return LocalCallbackManager::get ().getTLS ();
577
608
}
0 commit comments