Skip to content

Commit 1d26a3a

Browse files
James Reed authored and facebook-github-bot committed
Open registration for c10 thread pool (pytorch#17788)
Summary:
1. Move the ATen thread pool and its open registration mechanism to C10.
2. Move the `global_work_queue` to use this open registration mechanism, to allow users to substitute in their own thread pool.

Pull Request resolved: pytorch#17788
Reviewed By: zdevito
Differential Revision: D14379707
Pulled By: jamesr66a
fbshipit-source-id: 949662d0024875abf09907d97db927f160c54d45
1 parent 0955592 commit 1d26a3a

17 files changed

+114 -92 lines changed

aten/src/ATen/core/thread_pool.cpp renamed to c10/core/thread_pool.cpp

+27-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
#include <ATen/core/ivalue.h>
2-
#include <ATen/core/thread_pool.h>
1+
#include <c10/core/thread_pool.h>
32

43
namespace c10 {
54

@@ -125,9 +124,32 @@ void setNumThreads(size_t v) {
125124
}
126125
}
127126

128-
ThreadPool& global_work_queue() {
129-
static ThreadPool thread_pool(num_threads.exchange(-1));
130-
return thread_pool;
127+
TaskThreadPoolBase& global_work_queue() {
128+
static std::shared_ptr<TaskThreadPoolBase> pool =
129+
ThreadPoolRegistry()->Create("C10", 0, num_threads.exchange(-1), false);
130+
return *pool;
131131
}
132132

133+
C10_DEFINE_SHARED_REGISTRY(
134+
ThreadPoolRegistry,
135+
TaskThreadPoolBase,
136+
int,
137+
int,
138+
bool);
139+
140+
namespace {
141+
142+
std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
143+
int device_id,
144+
int pool_size,
145+
bool create_new) {
146+
static std::shared_ptr<TaskThreadPoolBase> pool =
147+
std::make_shared<ThreadPool>(pool_size);
148+
return pool;
149+
}
150+
151+
} // namespace
152+
153+
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
154+
133155
} // namespace c10

aten/src/ATen/core/thread_pool.h renamed to c10/core/thread_pool.h

+27-4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
#include <c10/util/Optional.h>
1111
#include <c10/util/intrusive_ptr.h>
12+
#include <c10/util/numa.h>
13+
#include <c10/util/thread_name.h>
1214

1315
namespace c10 {
1416

@@ -17,7 +19,7 @@ struct Future;
1719
} // namespace ivalue
1820

1921
// TODO: move this to C10 and make it C10_API
20-
class CAFFE2_API TaskThreadPoolBase {
22+
class C10_API TaskThreadPoolBase {
2123
public:
2224
virtual void run(const std::function<void()>& func) = 0;
2325

@@ -36,7 +38,7 @@ class CAFFE2_API TaskThreadPoolBase {
3638
virtual ~TaskThreadPoolBase() noexcept {}
3739
};
3840

39-
class CAFFE2_API ThreadPool : public c10::TaskThreadPoolBase {
41+
class C10_API ThreadPool : public c10::TaskThreadPoolBase {
4042
protected:
4143
struct task_element_t {
4244
bool run_with_id;
@@ -100,8 +102,29 @@ class CAFFE2_API ThreadPool : public c10::TaskThreadPoolBase {
100102
void main_loop(std::size_t index);
101103
};
102104

103-
CAFFE2_API void setNumThreads(size_t v);
105+
C10_API void setNumThreads(size_t v);
104106

105-
CAFFE2_API ThreadPool& global_work_queue();
107+
C10_API TaskThreadPoolBase& global_work_queue();
108+
109+
class C10_API TaskThreadPool : public c10::ThreadPool {
110+
public:
111+
explicit TaskThreadPool(
112+
std::size_t pool_size,
113+
int numa_node_id = -1)
114+
: ThreadPool(pool_size, numa_node_id) {}
115+
116+
// TODO move this to ATen/core/thread_pool.h
117+
void init_thread() override {
118+
setThreadName("CaffeTaskThread");
119+
NUMABind(numa_node_id_);
120+
}
121+
};
122+
123+
C10_DECLARE_SHARED_REGISTRY(
124+
ThreadPoolRegistry,
125+
TaskThreadPoolBase,
126+
int,
127+
int,
128+
bool);
106129

107130
} // namespace c10

c10/macros/Export.h

+14
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@
3030
// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static
3131
// libraries.
3232

33+
// For build systems that do not directly depend on CMake and directly build
34+
// from the source directory (such as Buck), one may not have a cmake_macros.h
35+
// file at all. In this case, the build system is responsible for providing
36+
// correct macro definitions corresponding to the cmake_macros.h.in file.
37+
//
38+
// In such scenarios, one should define the macro
39+
// C10_USING_CUSTOM_GENERATED_MACROS
40+
// to inform this header that it does not need to include the cmake_macros.h
41+
// file.
42+
43+
#ifndef C10_USING_CUSTOM_GENERATED_MACROS
44+
#include "c10/macros/cmake_macros.h"
45+
#endif // C10_USING_CUSTOM_GENERATED_MACROS
46+
3347
#ifdef _WIN32
3448
#if defined(C10_BUILD_SHARED_LIBS)
3549
#define C10_EXPORT __declspec(dllexport)

caffe2/utils/thread_name.cc renamed to c10/util/thread_name.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
1-
#include "caffe2/utils/thread_name.h"
1+
#include "c10/util/thread_name.h"
22

33
#include <algorithm>
44

55
#if defined(__GLIBC__) && !defined(__APPLE__) && !defined(__ANDROID__)
6-
#define CAFFE2_HAS_PTHREAD_SETNAME_NP
6+
#define C10_HAS_PTHREAD_SETNAME_NP
77
#endif
88

9-
#ifdef CAFFE2_HAS_PTHREAD_SETNAME_NP
9+
#ifdef C10_HAS_PTHREAD_SETNAME_NP
1010
#include <pthread.h>
1111
#endif
1212

13-
namespace caffe2 {
13+
namespace c10 {
1414

1515
void setThreadName(std::string name) {
16-
#ifdef CAFFE2_HAS_PTHREAD_SETNAME_NP
16+
#ifdef C10_HAS_PTHREAD_SETNAME_NP
1717
constexpr size_t kMaxThreadName = 15;
1818
name.resize(std::min(name.size(), kMaxThreadName));
1919

2020
pthread_setname_np(pthread_self(), name.c_str());
2121
#endif
2222
}
2323

24-
} // namespace caffe2
24+
} // namespace c10

c10/util/thread_name.h

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include <string>
4+
5+
#include "c10/macros/Export.h"
6+
7+
namespace c10 {
8+
9+
C10_API void setThreadName(std::string name);
10+
11+
} // namespace c10

caffe2/core/net.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <unordered_map>
1010
#include <vector>
1111

12+
#include "c10/core/thread_pool.h"
1213
#include "c10/util/Registry.h"
1314
#include "caffe2/core/blob.h"
1415
#include "caffe2/core/common.h"
@@ -19,7 +20,6 @@
1920
#include "caffe2/core/workspace.h"
2021
#include "caffe2/proto/caffe2_pb.h"
2122
#include "caffe2/utils/simple_queue.h"
22-
#include "caffe2/utils/thread_pool.h"
2323

2424
C10_DECLARE_string(caffe2_override_executor);
2525

caffe2/core/net_async_base.cc

+18-21
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ TaskThreadPoolBase* AsyncNetBase::poolGetter(
151151
std::unique_lock<std::mutex> pools_lock(pools_mutex_);
152152
auto pool = pools[device_id][pool_size];
153153
if (!pool) {
154-
pool = ThreadPoolRegistry()->Create(
154+
pool = c10::ThreadPoolRegistry()->Create(
155155
DeviceTypeName(device_type),
156156
device_id,
157157
pool_size,
@@ -478,26 +478,6 @@ AsyncNetBase::~AsyncNetBase() {
478478
}
479479
}
480480

481-
C10_DEFINE_SHARED_REGISTRY(
482-
ThreadPoolRegistry,
483-
TaskThreadPoolBase,
484-
int,
485-
int,
486-
bool);
487-
488-
C10_REGISTER_CREATOR(
489-
ThreadPoolRegistry,
490-
CPU,
491-
GetAsyncNetThreadPool<TaskThreadPool, PROTO_CPU>);
492-
C10_REGISTER_CREATOR(
493-
ThreadPoolRegistry,
494-
CUDA,
495-
GetAsyncNetThreadPool<TaskThreadPool, PROTO_CUDA>);
496-
C10_REGISTER_CREATOR(
497-
ThreadPoolRegistry,
498-
HIP,
499-
GetAsyncNetThreadPool<TaskThreadPool, PROTO_HIP>);
500-
501481
ExecutionOptions::ExecutionOptions(
502482
const std::shared_ptr<const NetDef>& net_def) {
503483
static const std::string kDag = "dag";
@@ -558,3 +538,20 @@ ExecutionOptions::ExecutionOptions(
558538
}
559539

560540
} // namespace caffe2
541+
542+
namespace c10 {
543+
544+
C10_REGISTER_CREATOR(
545+
ThreadPoolRegistry,
546+
CPU,
547+
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CPU>);
548+
C10_REGISTER_CREATOR(
549+
ThreadPoolRegistry,
550+
CUDA,
551+
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CUDA>);
552+
C10_REGISTER_CREATOR(
553+
ThreadPoolRegistry,
554+
HIP,
555+
caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_HIP>);
556+
557+
} // namespace c10

caffe2/core/net_async_base.h

+1-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
22
#define CAFFE2_CORE_NET_ASYNC_BASE_H_
33

4+
#include "c10/core/thread_pool.h"
45
#include "c10/util/Registry.h"
56
#include "caffe2/core/common.h"
67
#include "caffe2/core/net.h"
@@ -12,7 +13,6 @@
1213
#include "caffe2/proto/caffe2_pb.h"
1314
#include "caffe2/proto/prof_dag.pb.h"
1415
#include "caffe2/utils/proto_utils.h"
15-
#include "caffe2/utils/thread_pool.h"
1616

1717
C10_DECLARE_int(caffe2_streams_per_gpu);
1818
C10_DECLARE_int(caffe2_net_async_max_gpus);
@@ -167,13 +167,6 @@ class CAFFE2_API AsyncNetBase : public NetBase {
167167
friend class tracing::Tracer;
168168
};
169169

170-
C10_DECLARE_SHARED_REGISTRY(
171-
ThreadPoolRegistry,
172-
TaskThreadPoolBase,
173-
int,
174-
int,
175-
bool);
176-
177170
class AsyncNetExecutorHelper : public ExecutorHelper {
178171
public:
179172
explicit AsyncNetExecutorHelper(AsyncNetBase* net) : net_(net) {}

caffe2/core/net_parallel.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ TaskThreadPoolBase* ParallelNet::poolGetter(
132132
std::unique_lock<std::mutex> pools_lock(pools_mutex_);
133133
auto pool = pools[device_id][pool_size];
134134
if (!pool) {
135-
pool = ThreadPoolRegistry()->Create(
135+
pool = c10::ThreadPoolRegistry()->Create(
136136
DeviceTypeName(device_type),
137137
device_id,
138138
pool_size,

caffe2/image/image_input_op.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
#include <iostream>
88
#include <algorithm>
99

10+
#include "c10/core/thread_pool.h"
1011
#include "caffe2/core/common.h"
1112
#include "caffe2/core/db.h"
13+
#include "caffe2/image/transform_gpu.h"
14+
#include "caffe2/operators/prefetch_op.h"
1215
#include "caffe2/proto/caffe2_legacy.pb.h"
1316
#include "caffe2/utils/cast.h"
1417
#include "caffe2/utils/math.h"
15-
#include "caffe2/utils/thread_pool.h"
16-
#include "caffe2/operators/prefetch_op.h"
17-
#include "caffe2/image/transform_gpu.h"
1818

1919
namespace caffe2 {
2020

caffe2/utils/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ list(APPEND Caffe2_CPU_SRCS
1313
utils/signal_handler.cc
1414
utils/smart_tensor_printer.cc
1515
utils/string_utils.cc
16-
utils/thread_name.cc
1716
utils/threadpool/ThreadPool.cc)
1817

1918
# ---[ threadpool/pthreadpool* is a local modification of the NNPACK

caffe2/utils/thread_name.h

-11
This file was deleted.

caffe2/utils/thread_pool.h

-26
This file was deleted.

caffe2/utils/threadpool/WorkersPool.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
#include <atomic>
44
#include <condition_variable>
55
#include <thread>
6+
#include "c10/util/thread_name.h"
67
#include "caffe2/core/common.h"
78
#include "caffe2/core/logging.h"
8-
#include "caffe2/utils/thread_name.h"
99

1010
#if defined(_MSC_VER)
1111
#include <intrin.h>
@@ -263,7 +263,7 @@ class alignas(kGEMMLOWPCacheLineSize) Worker {
263263

264264
// Thread entry point.
265265
void ThreadFunc() {
266-
setThreadName("CaffeWorkersPool");
266+
c10::setThreadName("CaffeWorkersPool");
267267
ChangeState(State::Ready);
268268

269269
// Thread main loop

caffe2/video/video_input_op.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
#include <random>
77
#include <string>
88

9+
#include <c10/core/thread_pool.h>
910
#include <caffe2/core/db.h>
1011
#include <caffe2/core/logging.h>
1112
#include <caffe2/operators/prefetch_op.h>
1213
#include <caffe2/utils/math.h>
13-
#include <caffe2/utils/thread_pool.h>
1414
#include <caffe2/video/video_io.h>
1515

1616
namespace caffe2 {

torch/csrc/jit/interpreter.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include <ATen/core/ivalue.h>
1414
#include <torch/csrc/jit/operator.h>
1515
#include <torch/csrc/jit/script/jit_exception.h>
16-
#include <ATen/core/thread_pool.h>
16+
#include <c10/core/thread_pool.h>
1717

1818
#include <exception>
1919
#include <iostream>

torch/csrc/jit/register_prim_ops.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include <ATen/ExpandUtils.h>
1515
#include <ATen/WrapDimUtils.h>
1616
#include <ATen/core/ivalue.h>
17-
#include <ATen/core/thread_pool.h>
17+
#include <c10/core/thread_pool.h>
1818
#include <c10/util/SmallVector.h>
1919

2020
#include <algorithm>

0 commit comments

Comments
 (0)