Skip to content

Commit 089d658

Browse files
Mikhail Zolotukhinfacebook-github-bot
Mikhail Zolotukhin
authored andcommitted
[TensorExpr] Add classes for memory management in tensor expressions. (pytorch#33216)
Summary: Pull Request resolved: pytorch#33216 All tensor expressions belong to a kernel arena and are freed when the arena is destroyed. Until it is destroyed, all expressions stay valid. Test Plan: Imported from OSS Differential Revision: D19848382 Pulled By: ZolotukhinM fbshipit-source-id: a581ea2b635b9ba2cc53949616a13d8d3a47caae
1 parent 616beb1 commit 089d658

File tree

4 files changed

+121
-0
lines changed

4 files changed

+121
-0
lines changed

Diff for: caffe2/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
454454
${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp
455455
${TORCH_SRC_DIR}/csrc/jit/function.cpp
456456
${TORCH_SRC_DIR}/csrc/jit/vararg_functions.cpp
457+
458+
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
457459
)
458460

459461
if (NOT INTERN_DISABLE_MOBILE_INTERP)

Diff for: tools/build_variables.bzl

+2
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,11 @@ libtorch_sources = [
190190
"torch/csrc/jit/mobile/register_mobile_ops.cpp",
191191
"torch/csrc/jit/mobile/interpreter.cpp",
192192
"torch/csrc/jit/mobile/type_parser.cpp",
193+
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
193194
"torch/csrc/utils/byte_order.cpp",
194195
"torch/csrc/utils/tensor_flatten.cpp",
195196
"torch/csrc/utils/variadic.cpp",
197+
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
196198
]
197199

198200
libtorch_cuda_sources = [

Diff for: torch/csrc/jit/tensorexpr/mem_arena.cpp

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include "torch/csrc/jit/tensorexpr/mem_arena.h"
2+
3+
namespace torch {
4+
namespace jit {
5+
namespace tensorexpr {
6+
7+
namespace {
8+
// Define in an anonymous namespace to hide this symbol from other compilation
9+
// units
10+
thread_local KernelArena* current_arena = nullptr;
11+
}
12+
13+
KernelArena::~KernelArena() {
14+
for (KernelScopedObject* p : kernel_objects_) {
15+
delete p;
16+
}
17+
}
18+
19+
KernelScopedObject::KernelScopedObject() {
20+
KernelArena* kernel = KernelArena::GetCurrentKernelArena();
21+
kernel->kernel_objects_.push_back(this);
22+
}
23+
24+
static std::vector<KernelArena*>& GetKernelArenaStack() {
25+
thread_local std::vector<KernelArena*> kernel_arena_stack;
26+
return kernel_arena_stack;
27+
}
28+
29+
void KernelArena::SetCurrentKernelArena(KernelArena *new_kernel_arena) {
30+
current_arena = new_kernel_arena;
31+
}
32+
33+
KernelArena* KernelArena::GetCurrentKernelArena() {
34+
return current_arena;
35+
}
36+
37+
KernelScope::KernelScope() : owning_(true) {
38+
old_kernel_arena_ = KernelArena::GetCurrentKernelArena();
39+
KernelArena::SetCurrentKernelArena(new KernelArena);
40+
}
41+
42+
KernelScope::KernelScope(KernelArena* arena_) : owning_(false) {
43+
old_kernel_arena_ = KernelArena::GetCurrentKernelArena();
44+
KernelArena::SetCurrentKernelArena(arena_);
45+
}
46+
47+
KernelScope::~KernelScope() {
48+
if (owning_) {
49+
delete KernelArena::GetCurrentKernelArena();
50+
}
51+
KernelArena::SetCurrentKernelArena(old_kernel_arena_);
52+
}
53+
54+
} // namespace tensorexpr
55+
} // namespace jit
56+
} // namespace torch

Diff for: torch/csrc/jit/tensorexpr/mem_arena.h

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#pragma once
2+
#include <vector>
3+
#include "torch/csrc/WindowsTorchApiMacro.h"
4+
5+
namespace torch {
6+
namespace jit {
7+
namespace tensorexpr {
8+
9+
class KernelScopedObject;
10+
11+
// An arena that manages all the underlying kernel-scoped objects.
12+
class KernelArena {
13+
public:
14+
static KernelArena* GetCurrentKernelArena();
15+
static void SetCurrentKernelArena(KernelArena* new_arena);
16+
TORCH_API KernelArena() {}
17+
TORCH_API ~KernelArena();
18+
19+
private:
20+
KernelArena(const KernelArena&) = delete;
21+
KernelArena& operator=(const KernelArena&) = delete;
22+
friend class KernelScopedObject;
23+
std::vector<KernelScopedObject*> kernel_objects_; // owned
24+
};
25+
26+
// A RAII convenience wrapper on top of a kernel.
27+
// It either creates or takes an existing Kernel and sets it as the current
28+
// Kernel. When this object is destroyed, the previous Kernel is set as current,
29+
// and the created kernel is freed. If the kernel was passed, it stays alive.
30+
class KernelScope {
31+
public:
32+
TORCH_API KernelScope();
33+
TORCH_API explicit KernelScope(KernelArena* arena_);
34+
TORCH_API ~KernelScope();
35+
36+
private:
37+
KernelScope(const KernelScope&) = delete;
38+
KernelScope& operator=(const KernelScope&) = delete;
39+
KernelArena* kernel_arena_ = nullptr; // arena to be used in this scope
40+
KernelArena* old_kernel_arena_ =
41+
nullptr; // previous arena, will be restored in destructor
42+
bool owning_ = false; // determines whether the arena will be freed along with
43+
// the scope object
44+
};
45+
46+
// The base object managed by the Kernel.
47+
// The object must be created through "new", and when the Kernel is destroyed,
48+
// All its registered objects are destroyed through "delete".
49+
class TORCH_API KernelScopedObject {
50+
public:
51+
KernelScopedObject();
52+
virtual ~KernelScopedObject() = default;
53+
54+
private:
55+
KernelScopedObject(const KernelScopedObject&) = delete;
56+
KernelScopedObject& operator=(const KernelScopedObject&) = delete;
57+
};
58+
59+
} // namespace tensorexpr
60+
} // namespace jit
61+
} // namespace torch

0 commit comments

Comments
 (0)