Skip to content

Commit c083489

Browse files
swolchok authored and pytorchmergebot committed
[kineto] Optimize getStepCallbacks for common case of no active callbacks
Pull Request resolved: pytorch#77804

IIUC (if I understand correctly), the result of this function will be empty and unused if there are no sampled callbacks, which is the common case. We can accelerate this case by wrapping the result in an optional, which avoids initializing an empty SmallVector.

Differential Revision: [D36497279](https://our.internmc.facebook.com/intern/diff/D36497279/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36497279/)!

Approved by: https://github.com/robieta
1 parent 02c4d87 commit c083489

File tree

7 files changed

+59
-25
lines changed

7 files changed

+59
-25
lines changed

aten/src/ATen/core/dispatch/Dispatcher.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,9 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandl
545545
.template getDispatchKeySetUnboxed<Args...>(args...);
546546
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
547547
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
548-
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION);
549-
if (C10_UNLIKELY(!step_callbacks.empty() && op.operatorDef_->op.isObserved())) {
550-
return callWithDispatchKeySlowPath<Return, Args...>(op, step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
548+
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
549+
if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
550+
return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
551551
}
552552
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
553553
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
@@ -568,9 +568,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
568568
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
569569
const auto& kernel = entry.lookup(dispatchKeySet);
570570
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
571-
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION);
572-
if (C10_UNLIKELY(!step_callbacks.empty() && entry.isObserved())) {
573-
at::RecordFunction guard(std::move(step_callbacks));
571+
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
572+
if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
573+
at::RecordFunction guard(std::move(*step_callbacks));
574574
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
575575
auto& schema = op.schema();
576576
auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);

aten/src/ATen/record_function.cpp

+34-3
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class CacheEntry {
130130
// The caller is expected to check `GlobalCallbackManager::get().version()'
131131
// and call CacheEntry::update() if necessary.
132132
StepCallbacks getActiveCallbacks();
133+
c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty();
133134

134135
// Full rebuild. (E.g. during registration)
135136
void update(const std::vector<RecordFunctionCallback>& callbacks);
@@ -142,6 +143,8 @@ class CacheEntry {
142143
int tries_left_{-1};
143144
};
144145

146+
C10_ALWAYS_INLINE void getActiveCallbacksImpl();
147+
145148
void rebuildActiveCallbacks();
146149
int sampleTries(double p) const;
147150

@@ -169,6 +172,7 @@ class LocalCallbackManager {
169172
public:
170173
const RecordFunctionTLS& getTLS() const;
171174
StepCallbacks getActiveCallbacks(const RecordScope scope);
175+
c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty(const RecordScope scope);
172176

173177
void setTLS(const RecordFunctionTLS& tls);
174178
void seed(uint32_t seed);
@@ -178,6 +182,8 @@ class LocalCallbackManager {
178182
void clearCallbacks();
179183

180184
private:
185+
void rebuildActiveCallbacksIfNeeded();
186+
181187
void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot);
182188

183189
void rebuild_callback_scopes(
@@ -271,7 +277,7 @@ void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) {
271277
rebuildActiveCallbacks();
272278
}
273279

274-
StepCallbacks CacheEntry::getActiveCallbacks() {
280+
void CacheEntry::getActiveCallbacksImpl() {
275281
// We rebuild the active set when `sampling_countdown_` reaches zero, so if it
276282
// reaches zero at the start of this function something has gone wrong.
277283
TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_);
@@ -295,7 +301,18 @@ StepCallbacks CacheEntry::getActiveCallbacks() {
295301
}
296302
}
297303
}
304+
}
298305

306+
StepCallbacks CacheEntry::getActiveCallbacks() {
307+
getActiveCallbacksImpl();
308+
return active_callbacks_;
309+
}
310+
311+
c10::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
312+
getActiveCallbacksImpl();
313+
if (C10_LIKELY(active_callbacks_.empty())) {
314+
return c10::nullopt;
315+
}
299316
return active_callbacks_;
300317
}
301318

@@ -365,15 +382,25 @@ const RecordFunctionTLS& LocalCallbackManager::getTLS() const {
365382
return registered_callbacks_;
366383
}
367384

368-
StepCallbacks LocalCallbackManager::getActiveCallbacks(
369-
const RecordScope scope) {
385+
void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() {
370386
const auto global_version = GlobalCallbackManager::get().version();
371387
if (C10_UNLIKELY(global_version != global_version_)) {
372388
rebuild_all(GlobalCallbackManager::get().getSnapshot());
373389
}
390+
}
391+
392+
StepCallbacks LocalCallbackManager::getActiveCallbacks(
393+
const RecordScope scope) {
394+
rebuildActiveCallbacksIfNeeded();
374395
return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks();
375396
}
376397

398+
c10::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty(
399+
const RecordScope scope) {
400+
rebuildActiveCallbacksIfNeeded();
401+
return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacksUnlessEmpty();
402+
}
403+
377404
void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) {
378405
registered_callbacks_ = tls;
379406
rebuild_all(GlobalCallbackManager::get().getSnapshot());
@@ -572,6 +599,10 @@ StepCallbacks getStepCallbacks(RecordScope scope) {
572599
return LocalCallbackManager::get().getActiveCallbacks(scope);
573600
}
574601

602+
c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope) {
603+
return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope);
604+
}
605+
575606
const RecordFunctionTLS& get_record_function_tls_() {
576607
return LocalCallbackManager::get().getTLS();
577608
}

aten/src/ATen/record_function.h

+2
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ struct TORCH_API RecordFunction {
478478

479479
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);
480480

481+
TORCH_API c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope);
482+
481483
namespace detail {
482484
template <typename Inputs, typename F, typename... Args>
483485
void record_function_with_scope(RecordFunction& guard, F fn, const Inputs& inputs, Args&&... args) {

binaries/record_function_benchmark.cc

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
#include <torch/torch.h>
23
#include <ATen/record_function.h>
34

@@ -49,9 +50,9 @@ float runPureRecordFunctionBench(int iter) {
4950
typedef std::chrono::microseconds us;
5051
std::chrono::time_point<clock> start_time = clock::now();
5152
for (auto idx = 0; idx < iter; ++idx) {
52-
auto step_callbacks = at::getStepCallbacks(at::RecordScope::USER_SCOPE);
53-
if (!step_callbacks.empty()) {
54-
at::RecordFunction guard(std::move(step_callbacks));
53+
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE);
54+
if (step_callbacks.has_value()) {
55+
at::RecordFunction guard(std::move(*step_callbacks));
5556
guard.before("Test", -1);
5657
}
5758
}

torch/csrc/autograd/function.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
151151
// probably operate with names.
152152
at::NoNamesGuard no_names_guard;
153153

154-
auto step_callbacks = at::getStepCallbacks(at::RecordScope::BACKWARD_FUNCTION);
155-
if (!step_callbacks.empty()) {
156-
at::RecordFunction guard(std::move(step_callbacks));
154+
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
155+
if (C10_UNLIKELY(step_callbacks.has_value())) {
156+
at::RecordFunction guard(std::move(*step_callbacks));
157157
// Using sequence number and thread id to correlate with
158158
// the forward pass function
159159
guard.setForwardThreadId(thread_id_);

torch/csrc/jit/runtime/interpreter.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -845,11 +845,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
845845

846846
static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
847847
if (!frame.record_function) {
848-
auto step_callbacks =
849-
at::getStepCallbacks(at::RecordScope::TORCHSCRIPT_FUNCTION);
850-
if (!step_callbacks.empty()) {
848+
auto step_callbacks = at::getStepCallbacksUnlessEmpty(
849+
at::RecordScope::TORCHSCRIPT_FUNCTION);
850+
if (C10_UNLIKELY(step_callbacks.has_value())) {
851851
auto rec_fn =
852-
std::make_unique<at::RecordFunction>(std::move(step_callbacks));
852+
std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
853853
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
854854
if (rec_fn->needsInputs()) {
855855
rec_fn->before(

torch/csrc/jit/runtime/static/impl.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1201,9 +1201,9 @@ c10::IValue BlockRunner::run_impl_record_functions(
12011201
IValueList&& args,
12021202
const KeywordArgs& kwargs) {
12031203
auto step_callbacks =
1204-
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_MODEL);
1205-
if (!step_callbacks.empty()) {
1206-
at::RecordFunction guard(std::move(step_callbacks));
1204+
at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
1205+
if (C10_UNLIKELY(step_callbacks.has_value())) {
1206+
at::RecordFunction guard(std::move(*step_callbacks));
12071207
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
12081208
guard.needsInputs()
12091209
? guard.before(
@@ -1845,9 +1845,9 @@ std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
18451845
void ProcessedNode::run() {
18461846
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
18471847
auto step_callbacks =
1848-
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_OP);
1849-
if (!step_callbacks.empty()) {
1850-
at::RecordFunction guard(std::move(step_callbacks));
1848+
at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP);
1849+
if (C10_UNLIKELY(step_callbacks.has_value())) {
1850+
at::RecordFunction guard(std::move(*step_callbacks));
18511851
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
18521852
if (guard.needsInputs()) {
18531853
const auto inputs = inputs_ivalue_vec();

0 commit comments

Comments
 (0)