Skip to content

Commit 48db74e

Browse files
Yangqing Jiafacebook-github-bot
Yangqing Jia
authored andcommitted
net_simple_refcount type to help experimentation with dynamic allocation. (pytorch#13370)
Summary: Pull Request resolved: pytorch#13370 This diff adds a new net type (simple_refcount) that does one thing: for all intermediate results produced by a net, it will keep refcount about internal usage, and when it finishes its consumption, the net will delete the blob content to mimic the case of dynamic allocation. In fact, this would also be the behavior when we go functional: anything that is not explicitly marked as input or output will be up to the executor for lifetime management. See the comments in net_simple_refcount.cc for details. Reviewed By: dzhulgakov Differential Revision: D12855489 fbshipit-source-id: 594a47a786305d595fd505b6700864dd1d9c72aa
1 parent 479b826 commit 48db74e

File tree

3 files changed

+211
-0
lines changed

3 files changed

+211
-0
lines changed

caffe2/core/net_simple_refcount.cc

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#include "caffe2/core/net_simple_refcount.h"
2+
#include "caffe2/core/net.h"
3+
4+
#include <iostream>
5+
#include <set>
6+
#include <unordered_map>
7+
#include <unordered_set>
8+
9+
#include "caffe2/core/operator.h"
10+
#include "caffe2/core/static_tracepoint.h"
11+
#include "caffe2/core/timer.h"
12+
#include "caffe2/proto/caffe2_pb.h"
13+
#include "caffe2/utils/proto_utils.h"
14+
15+
namespace caffe2 {
16+
17+
SimpleRefCountNet::SimpleRefCountNet(
18+
const std::shared_ptr<const NetDef>& net_def,
19+
Workspace* ws)
20+
: SimpleNet(net_def, ws) {
21+
VLOG(1) << "Constructing SimpleRefCountNet " << net_def->name();
22+
// Construct the "to delete" list.
23+
delete_list_.resize(net_def->op_size());
24+
25+
std::map<string, int> last_consumed_at;
26+
std::set<string> created_by_me;
27+
// For each opeartor
28+
for (int idx = 0; idx < net_def->op_size(); ++idx) {
29+
const auto& op_def = net_def->op(idx);
30+
for (const string& in_name : op_def.input()) {
31+
last_consumed_at[in_name] = idx;
32+
}
33+
for (const string& out_name : op_def.output()) {
34+
created_by_me.insert(out_name);
35+
}
36+
}
37+
// We do not delete any operator that is not produced by the net, and
38+
// any operator that is marked as external_output. Any blob that is not
39+
// consumed won't be in the last_consumed_at map, so we don't need to
40+
// do anything special.
41+
for (auto& kv : last_consumed_at) {
42+
if (!created_by_me.count(kv.first)) {
43+
kv.second = -1;
44+
}
45+
}
46+
for (const string& name : net_def->external_output()) {
47+
last_consumed_at[name] = -1;
48+
}
49+
// Set up the delete list.
50+
for (auto& kv : last_consumed_at) {
51+
if (kv.second > 0) {
52+
delete_list_[kv.second].push_back(ws->GetBlob(kv.first));
53+
VLOG(1) << "NetSimpleRefCountNet: will delete " << kv.first
54+
<< " at operator #" << kv.second;
55+
}
56+
}
57+
}
58+
59+
bool SimpleRefCountNet::Run() {
60+
StartAllObservers();
61+
VLOG(1) << "Running net " << name_;
62+
for (int op_id = 0; op_id < operators_.size(); ++op_id) {
63+
auto& op = operators_[op_id];
64+
VLOG(1) << "Running operator " << op->debug_def().name() << "("
65+
<< op->debug_def().type() << ").";
66+
bool res = op->Run();
67+
if (!res) {
68+
LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
69+
return false;
70+
}
71+
for (Blob* blob : delete_list_[op_id]) {
72+
blob->Reset();
73+
}
74+
}
75+
StopAllObservers();
76+
return true;
77+
}
78+
79+
REGISTER_NET(simple_refcount, SimpleRefCountNet);
80+
81+
} // namespace caffe2

caffe2/core/net_simple_refcount.h

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#ifndef CAFFE2_CORE_NET_SIMPLE_REFCOUNT_H_
2+
#define CAFFE2_CORE_NET_SIMPLE_REFCOUNT_H_
3+
4+
#include <vector>
5+
6+
#include "c10/util/Registry.h"
7+
#include "caffe2/core/common.h"
8+
#include "caffe2/core/logging.h"
9+
#include "caffe2/core/net.h"
10+
#include "caffe2/core/net_simple.h"
11+
#include "caffe2/core/tensor.h"
12+
#include "caffe2/core/workspace.h"
13+
#include "caffe2/proto/caffe2_pb.h"
14+
15+
namespace caffe2 {
16+
17+
// SimpleRefcountNet is an implementation that adds an additional abstraction
18+
// on top of SimpleRefCountNet: it tracks all the tensors and for those that are
19+
// considered internal/temporary, delete them once their refcount go to zero.
20+
// In the context of a simple static run, this can be carried out during
21+
// construction time: we will do a pass through the network and track what
22+
// blobs we need to do reset on, after the execution of every op.
23+
//
24+
// To identify which blob is considered temporary, we employ the following
25+
// strategy: any blob that is
26+
// (1) consumed but not produced by ops in the net, or
27+
// (2) produced but not consumed by ops in the net, or
28+
// (3) is marked as external_output in the protobuf
29+
// will NOT be considered temporary.
30+
//
31+
// In the long run, we should design proper functional interfaces so that
32+
// nets are less imperative and more functional.
33+
//
34+
// Also, for now, SimpleRefCountNet should only be used for benchmarking
35+
// purposes and not product use, since it is not going to provide better
36+
// performance gain, and is implicitly incompatible with the contract that
37+
// earlier Nets expose - that all intermediate blobs are visible to the users.
38+
class SimpleRefCountNet final : public SimpleNet {
39+
public:
40+
SimpleRefCountNet(
41+
const std::shared_ptr<const NetDef>& net_def,
42+
Workspace* ws);
43+
44+
protected:
45+
bool Run() override;
46+
47+
using SimpleNet::operators_;
48+
49+
private:
50+
// The list of blobs to delete when each operator finishes its run.
51+
// This will be populated during construction time.
52+
vector<vector<Blob*>> delete_list_;
53+
54+
C10_DISABLE_COPY_AND_ASSIGN(SimpleRefCountNet);
55+
};
56+
57+
} // namespace caffe2
58+
59+
#endif // CAFFE2_CORE_NET_SIMPLE_REFCOUNT_H_
+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#include <gtest/gtest.h>
2+
#include "c10/util/StringUtil.h"
3+
#include "caffe2/core/net.h"
4+
#include "caffe2/core/net_async_scheduling.h"
5+
#include "caffe2/core/net_dag.h"
6+
#include "caffe2/core/operator.h"
7+
#include "caffe2/core/scope_guard.h"
8+
9+
#include <google/protobuf/text_format.h>
10+
11+
namespace caffe2 {
12+
13+
namespace {
14+
15+
// A net test dummy op that does nothing but scaffolding. Here, we
16+
// inherit from OperatorBase because we instantiate on both CPU and
17+
// GPU. In general, you want to only inherit from Operator<Context>.
18+
class NetSimpleRefCountTestOp final : public Operator<CPUContext> {
19+
public:
20+
NetSimpleRefCountTestOp(const OperatorDef& operator_def, Workspace* ws)
21+
: Operator<CPUContext>(operator_def, ws) {}
22+
USE_OPERATOR_FUNCTIONS(CPUContext);
23+
24+
bool RunOnDevice() override {
25+
const int32_t& input = OperatorBase::Input<int32_t>(0);
26+
int32_t* output = OperatorBase::Output<int32_t>(0);
27+
*output = input + 1;
28+
return true;
29+
}
30+
};
31+
32+
REGISTER_CPU_OPERATOR(NetSimpleRefCountTest, NetSimpleRefCountTestOp);
33+
34+
OPERATOR_SCHEMA(NetSimpleRefCountTest).NumInputs(1).NumOutputs(1);
35+
36+
TEST(NetSimpleRefCountTest, TestCorrectness) {
37+
Workspace ws;
38+
*(ws.CreateBlob("a")->GetMutable<int32_t>()) = 1;
39+
NetDef net_def;
40+
net_def.set_type("simple_refcount");
41+
net_def.add_op()->CopyFrom(
42+
CreateOperatorDef("NetSimpleRefCountTest", "", {"a"}, {"b"}));
43+
net_def.add_op()->CopyFrom(
44+
CreateOperatorDef("NetSimpleRefCountTest", "", {"b"}, {"c"}));
45+
net_def.add_op()->CopyFrom(
46+
CreateOperatorDef("NetSimpleRefCountTest", "", {"b"}, {"d"}));
47+
net_def.add_op()->CopyFrom(
48+
CreateOperatorDef("NetSimpleRefCountTest", "", {"c"}, {"e"}));
49+
// After execution, what should look like is:
50+
// a = 1
51+
// b = deallocated
52+
// c = deallocated
53+
// d = 3
54+
// e = 4
55+
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
56+
net->Run();
57+
// Note on ASSERT vs EXPECT: ASSERT will quit directly if condition not
58+
// met, which is why we guard IsType<> calls with ASSERT so that the
59+
// subsequent Get() calls do not product an exception.
60+
ASSERT_TRUE(ws.GetBlob("a")->IsType<int32_t>());
61+
EXPECT_EQ(ws.GetBlob("a")->Get<int32_t>(), 1);
62+
EXPECT_EQ(ws.GetBlob("b")->GetRaw(), nullptr);
63+
EXPECT_EQ(ws.GetBlob("c")->GetRaw(), nullptr);
64+
ASSERT_TRUE(ws.GetBlob("d")->IsType<int32_t>());
65+
EXPECT_EQ(ws.GetBlob("d")->Get<int32_t>(), 3);
66+
ASSERT_TRUE(ws.GetBlob("e")->IsType<int32_t>());
67+
EXPECT_EQ(ws.GetBlob("e")->Get<int32_t>(), 4);
68+
}
69+
70+
} // namespace
71+
} // namespace caffe2

0 commit comments

Comments
 (0)