Skip to content

Commit 046672e

Browse files
orionrfacebook-github-bot
authored andcommitted
Set proper scope on nodes added by JIT (pytorch#12400)
Summary: In order to support tensorboardX and other visualization tools, we need to make sure a non-empty scope is set on all nodes added by the JIT. This attempts to do this, but is still a WIP. This is a new version of pytorch#10749 Pull Request resolved: pytorch#12400 Reviewed By: ezyang Differential Revision: D10224380 Pulled By: orionr fbshipit-source-id: d1bccd0eee9ef7c4354112c6a39a5987bfac2994
1 parent cf235e0 commit 046672e

File tree

10 files changed

+163
-91
lines changed

10 files changed

+163
-91
lines changed

torch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ set(TORCH_SRCS
187187
${TORCH_SRC_DIR}/csrc/jit/passes/pretty_print.cpp
188188
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
189189
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
190+
${TORCH_SRC_DIR}/csrc/jit/scope.cpp
190191
${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
191192
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
192193
${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp

torch/csrc/jit/constants.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
namespace torch { namespace jit {
77

88
// IValue -> Constant node
9-
Value* insertConstant(Graph& g, IValue val, c10::optional<SourceRange> loc) {
9+
Value* insertConstant(
10+
Graph& g,
11+
IValue val,
12+
c10::optional<SourceRange> loc,
13+
c10::optional<ScopePtr> scope) {
1014
Node * n = g.create(prim::Constant);
1115
if(val.isTensor()) {
1216
at::Tensor ref = std::move(val).toTensor();
@@ -53,6 +57,8 @@ Value* insertConstant(Graph& g, IValue val, c10::optional<SourceRange> loc) {
5357
}
5458
if(loc)
5559
n->setSourceLocation(std::make_shared<SourceRange>(*loc));
60+
if(scope)
61+
n->setScope(*scope);
5662
return g.insertNode(n)->output();
5763
}
5864

torch/csrc/jit/constants.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include "torch/csrc/jit/ivalue.h"
3+
#include "torch/csrc/jit/scope.h"
34
#include "torch/csrc/jit/source_range.h"
45
#include "torch/csrc/WindowsTorchApiMacro.h"
56

@@ -22,7 +23,9 @@ struct TORCH_API constant_not_supported_error : public std::runtime_error {
2223
TORCH_API Value* insertConstant(
2324
Graph& g,
2425
IValue val,
25-
c10::optional<SourceRange> loc = c10::nullopt);
26+
c10::optional<SourceRange> loc = c10::nullopt,
27+
c10::optional<ScopePtr> scope = c10::nullopt);
28+
2629

2730
//////////////////////////////////////////////////////////////////////////////////
2831
// Helper for retrieving constants

torch/csrc/jit/ir.cpp

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -196,33 +196,6 @@ void Graph::dumpPretty() {
196196
PrettyPrint(std::cout, *this);
197197
}
198198

199-
ScopePtr Scope::push(Symbol name) {
200-
return c10::make_intrusive<Scope>(intrusive_from_this(), name);
201-
}
202-
203-
ScopePtr Scope::getRoot() {
204-
ScopePtr current = intrusive_from_this();
205-
while (current->parent_) {
206-
current = current->parent_;
207-
}
208-
return current;
209-
}
210-
211-
std::string Scope::namesFromRoot(const std::string& separator) const {
212-
// TODO: I think the answer is we shouldn't have used Symbol here
213-
std::string out = this->name_.toUnqualString();
214-
if (this->isRoot()) {
215-
return out;
216-
}
217-
ScopePtr parent = this->parent_;
218-
while (!parent->isRoot()) {
219-
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
220-
out = std::string(parent->name_.toUnqualString()) + separator + out;
221-
parent = parent->parent_;
222-
}
223-
return out;
224-
}
225-
226199
static void checkSameDevice(const Node* node) {
227200
bool has_device = false;
228201
int device;
@@ -1125,8 +1098,9 @@ Node* Graph::createClone(Node * n, std::function<Value*(Value*)> value_map, bool
11251098

11261099
Value* Graph::insertConstant(
11271100
IValue val,
1128-
c10::optional<SourceRange> loc) {
1129-
return jit::insertConstant(*this, std::move(val), loc);
1101+
c10::optional<SourceRange> loc,
1102+
c10::optional<ScopePtr> scope) {
1103+
return jit::insertConstant(*this, std::move(val), loc, scope);
11301104
}
11311105

11321106
Value* Graph::insertDummyWorld() {

torch/csrc/jit/ir.h

Lines changed: 3 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "torch/csrc/jit/graph_node_list.h"
77
#include "torch/csrc/jit/interned_strings.h"
88
#include "torch/csrc/jit/resource_guard.h"
9+
#include "torch/csrc/jit/scope.h"
910
#include "torch/csrc/jit/source_location.h"
1011
#include "torch/csrc/jit/source_range.h"
1112
#include "torch/csrc/jit/constants.h"
@@ -97,61 +98,6 @@ struct Use {
9798
// If you are looking for "use induced by an input", it's best to use
9899
// findUseForInput() to get it.
99100

100-
101-
// Scope is a node of a trie that represents the tree of nested scopes.
102-
// Individual scopes are pushed and popped from Graph, which holds a
103-
// pointer to the current scope. Each Node in Graph holds a pointer
104-
// to the scope that was current when the node was created.
105-
// The trie never needs to shrink, it only grows until it is disposed
106-
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
107-
// will always be valid as long as Graph is alive.
108-
struct Scope;
109-
using ScopePtr = c10::intrusive_ptr<Scope>;
110-
111-
struct TORCH_API Scope : public c10::intrusive_ptr_target {
112-
private:
113-
ScopePtr parent_;
114-
Symbol name_;
115-
ScopePtr intrusive_from_this() {
116-
c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
117-
// from a raw `this` pointer
118-
// so we need to bump the refcount
119-
// to account for this ownership
120-
return c10::intrusive_ptr<Scope>::reclaim(this);
121-
}
122-
public:
123-
Scope() {
124-
name_ = Symbol::scope("");
125-
}
126-
Scope(ScopePtr parent, Symbol name) {
127-
name_ = name;
128-
parent_ = parent;
129-
}
130-
ScopePtr push(Symbol name);
131-
132-
ScopePtr parent() {
133-
if (!parent_) {
134-
throw std::runtime_error("Cannot get parent from Scope with no parent");
135-
}
136-
return parent_;
137-
}
138-
bool isRoot() const {
139-
return !parent_;
140-
}
141-
bool isBlank() const {
142-
static const Symbol blank = Symbol::scope("");
143-
return isRoot() && name() == blank;
144-
}
145-
146-
ScopePtr getRoot();
147-
148-
Symbol name() const {
149-
return name_;
150-
}
151-
152-
std::string namesFromRoot(const std::string& separator="/") const;
153-
};
154-
155101
// the list types are intentionally simple, but we type-def
156102
// them here so if we need to change them, refactoring will be easier
157103
using node_list = std::vector<Node*>;
@@ -868,7 +814,8 @@ friend struct Block;
868814

869815
TORCH_API Value* insertConstant(
870816
IValue val,
871-
c10::optional<SourceRange> loc = c10::nullopt);
817+
c10::optional<SourceRange> loc = c10::nullopt,
818+
c10::optional<ScopePtr> scope = c10::nullopt);
872819

873820
TORCH_API Value* insertDummyWorld();
874821

torch/csrc/jit/passes/erase_number_types.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ static void EraseNumberTypesOnBlock(Block* block) {
1717
it->output()->type()->isSubtypeOf(BoolType::get())) {
1818
auto s = *constant_as<at::Scalar>(it->output());
1919
WithInsertPoint guard(*it);
20-
Value* r = block->owningGraph()->insertConstant(scalar_to_tensor(s));
20+
Value* r = block->owningGraph()->insertConstant(
21+
scalar_to_tensor(s), c10::nullopt, it->scope());
2122
it->output()->replaceAllUsesWith(r);
2223
}
2324
} break;

torch/csrc/jit/passes/onnx.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,10 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
5858
// Unfortunately, they are on the hook for all internal nodes
5959
// (though in practice, the types are not computed.)
6060
outputs[i]->setType(old->type());
61-
// Copy over source location information to all nodes created by
62-
// the symbolic
61+
// Copy over source location and scope information to all nodes
62+
// created by the symbolic
6363
outputs[i]->node()->setSourceLocation(node->getSourceLocation());
64+
outputs[i]->node()->setScope(node->scope());
6465
env[old] = outputs[i];
6566
} else {
6667
// Null output means that the ONNX op doesn't have outputs corresponding

torch/csrc/jit/scope.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "ir.h"
2+
3+
4+
#include "torch/csrc/jit/operator.h"
5+
#include "torch/csrc/autograd/function.h"
6+
#include "torch/csrc/jit/constants.h"
7+
#include "torch/csrc/jit/assertions.h"
8+
#include "torch/csrc/jit/script/compiler.h"
9+
#include "torch/csrc/jit/passes/pretty_print.h"
10+
11+
#include <iostream>
12+
#include <unordered_map>
13+
#include <unordered_set>
14+
#include <set>
15+
#include <stack>
16+
#include <sstream>
17+
#include <algorithm>
18+
#include <string>
19+
20+
namespace torch { namespace jit {
21+
22+
ScopePtr Scope::push(Symbol name) {
23+
return c10::make_intrusive<Scope>(intrusive_from_this(), name);
24+
}
25+
26+
ScopePtr Scope::getRoot() {
27+
ScopePtr current = intrusive_from_this();
28+
while (current->parent_) {
29+
current = current->parent_;
30+
}
31+
return current;
32+
}
33+
34+
size_t Scope::getDepth() {
35+
size_t d = 1;
36+
ScopePtr current = intrusive_from_this();
37+
while (current->parent_) {
38+
current = current->parent_;
39+
d += 1;
40+
}
41+
return d;
42+
}
43+
44+
std::string Scope::namesFromRoot(const std::string& separator) const {
45+
// TODO: I think the answer is we shouldn't have used Symbol here
46+
std::string out = this->name_.toUnqualString();
47+
if (this->isRoot()) {
48+
return out;
49+
}
50+
ScopePtr parent = this->parent_;
51+
while (!parent->isRoot()) {
52+
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
53+
out = std::string(parent->name_.toUnqualString()) + separator + out;
54+
parent = parent->parent_;
55+
}
56+
return out;
57+
}
58+
59+
}} // namespace torch::jit

torch/csrc/jit/scope.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#pragma once
2+
#include "torch/csrc/jit/interned_strings.h"
3+
#include "torch/csrc/jit/assertions.h"
4+
#include "torch/csrc/WindowsTorchApiMacro.h"
5+
#include "c10/macros/Macros.h"
6+
7+
#include <memory>
8+
9+
namespace torch {
10+
namespace jit {
11+
12+
// Scope is a node of a trie that represents the tree of nested scopes.
13+
// Individual scopes are pushed and popped from Graph, which holds a
14+
// pointer to the current scope. Each Node in Graph holds a pointer
15+
// to the scope that was current when the node was created.
16+
// The trie never needs to shrink, it only grows until it is disposed
17+
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
18+
// will always be valid as long as Graph is alive.
19+
struct Scope;
20+
using ScopePtr = c10::intrusive_ptr<Scope>;
21+
22+
struct TORCH_API Scope : public c10::intrusive_ptr_target {
23+
private:
24+
ScopePtr parent_;
25+
Symbol name_;
26+
ScopePtr intrusive_from_this() {
27+
c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
28+
// from a raw `this` pointer
29+
// so we need to bump the refcount
30+
// to account for this ownership
31+
return c10::intrusive_ptr<Scope>::reclaim(this);
32+
}
33+
public:
34+
Scope() {
35+
name_ = Symbol::scope("");
36+
}
37+
Scope(ScopePtr parent, Symbol name) {
38+
name_ = name;
39+
parent_ = parent;
40+
}
41+
ScopePtr push(Symbol name);
42+
43+
ScopePtr parent() {
44+
if (!parent_) {
45+
throw std::runtime_error("Cannot get parent from Scope with no parent");
46+
}
47+
return parent_;
48+
}
49+
bool isRoot() const {
50+
return !parent_;
51+
}
52+
bool isBlank() const {
53+
static const Symbol blank = Symbol::scope("");
54+
return isRoot() && name() == blank;
55+
}
56+
57+
ScopePtr getRoot();
58+
59+
size_t getDepth();
60+
61+
Symbol name() const {
62+
return name_;
63+
}
64+
65+
std::string namesFromRoot(const std::string& separator="/") const;
66+
};
67+
68+
} // namespace jit
69+
} // namespace torch

torch/csrc/jit/symbolic_variable.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,18 @@ struct SymbolicVariable {
3232
if(g == nullptr) {
3333
g = inputs.at(0).value()->owningGraph();
3434
}
35-
Node * n = g->insertNode(g->create(kind, num_outputs));
35+
Node* n = g->insertNode(g->create(kind, num_outputs));
36+
size_t max_depth = 0;
37+
ScopePtr s;
38+
for(auto n : inputs) {
39+
size_t d = n.value()->node()->scope()->getDepth();
40+
if(d > max_depth) {
41+
max_depth = d;
42+
s = n.value()->node()->scope();
43+
}
44+
}
45+
n->setScope(s);
46+
3647
for(auto i : inputs) {
3748
n->addInput(i.value());
3849
}

0 commit comments

Comments
 (0)