Commit a50ba7e

Ailing Zhang authored and facebook-github-bot committed
specialized CUDA impl for dropout in AD (pytorch#17756)
Summary: In ATen we have a _fused_dropout implementation for the CUDA case. As ngimel pointed out, discarding it in JIT AD hurts performance. Including a backend-specific implementation in AD is not ideal, but it helps prevent a performance regression for now.

Pull Request resolved: pytorch#17756
Differential Revision: D14368999
Pulled By: ailzhang
fbshipit-source-id: 9a371c5020f630e8f6e496849ec9772b6f196169
1 parent 9a15341 commit a50ba7e

File tree (3 files changed, +83 -15 lines):

test/test_jit.py
torch/csrc/jit/passes/shape_analysis.cpp
torch/csrc/jit/symbolic_script.cpp
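The gist of the change, as a minimal eager-mode Python sketch mirroring the symbolic_script.cpp diff below (the helper name dropout_forward_sketch is illustrative; torch._fused_dropout is the private ATen op referenced in the commit):

import torch

def dropout_forward_sketch(x, p=0.5):
    # p1m is the keep probability, matching the symbolic script below.
    p1m = 1.0 - p
    if x.is_cuda:
        # CUDA path: a single fused kernel returns the scaled output and the mask.
        res, mask = torch._fused_dropout(x, p1m)
    else:
        # CPU path: draw a Bernoulli mask, then scale by 1 / (1 - p).
        mask = torch.empty_like(x)
        mask.bernoulli_(p1m)
        res = mask * x / p1m
    return res, mask

res, mask = dropout_forward_sketch(torch.randn(4, 4))

Keeping the fused kernel in the JIT AD formula is what avoids the regression: the decomposed CPU-style path runs several separate ops where the CUDA kernel does the work in one.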

test/test_jit.py

Lines changed: 21 additions & 0 deletions
@@ -1341,6 +1341,27 @@ def test_dropout(self):
         m = self.createScriptModuleFromGraph(trace)
         self.assertEqual(outputs, m(*inputs))
 
+    @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA")
+    def test_dropout_cuda(self):
+        # Dropout AD is dispatched to _fused_dropout in CUDA case,
+        # which is not included in TestJitGeneratedFunctional
+        x = torch.ones(4, 4).cuda().requires_grad_()
+
+        @torch.jit.script
+        def func(x):
+            return torch.nn.functional.dropout(x)
+
+        with freeze_rng_state():
+            out_ref = torch.nn.functional.dropout(x)
+            grad_ref = torch.autograd.grad(out_ref.sum(), x)
+
+        with freeze_rng_state():
+            out = func(x)
+            grad = torch.autograd.grad(out.sum(), x)
+
+        self.assertEqual(out, out_ref)
+        self.assertEqual(grad, grad_ref)
+
     def test_conv(self):
         x = torch.ones(20, 16, 50, 40)
         trace, outputs, inputs = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True)

torch/csrc/jit/passes/shape_analysis.cpp

Lines changed: 39 additions & 11 deletions
@@ -48,6 +48,19 @@ bool isValidReturnForRunning(Value* v) {
          v->type()->isSubtypeOf(NumberType::get());
 }
 
+bool containsTensorType(const TypePtr& t) {
+  auto n_contained = t->containedTypes().size();
+  if (n_contained == 1) {
+    return t->containedTypes().at(0)->isSubtypeOf(TensorType::get());
+  } else if (n_contained > 1) {
+    return std::any_of(
+        t->containedTypes().begin(),
+        t->containedTypes().end(),
+        containsTensorType);
+  }
+  return false;
+}
+
 class ShapePropagator {
  public:
   explicit ShapePropagator(std::shared_ptr<Graph> graph) : aliasDb_(graph) {
@@ -298,6 +311,18 @@ class ShapePropagator {
     return true;
   }
 
+  // If there's no Tensor in outputs, e.g float / float,
+  // we don't need to propagate shape.
+  bool DoesntRefineOutputs(Node* node) {
+    auto outputs = node->outputs();
+    for (auto& out : outputs) {
+      if (containsTensorType(out->type())) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   bool PropagateShapeOnNodeByRunningIt(Node* node) {
     if (!canPropagateShapeByRunningIt(node))
       return false;
@@ -534,6 +559,10 @@ class ShapePropagator {
       return;
     }
 
+    if (DoesntRefineOutputs(node)) {
+      return;
+    }
+
     if (PropagateShapeOnNodeByRunningIt(node)) {
       return;
     }
@@ -1074,26 +1103,25 @@ class ShapePropagator {
       at::optional<IValue> maybe_layout_option = node->get(attr::layout);
       if (!maybe_layout_option)
         return {};
-      auto layout = (maybe_layout_option->isNone()
-          ? at::kStrided
-          : maybe_layout_option->toLayout());
+      auto layout =
+          (maybe_layout_option->isNone() ? at::kStrided
+                                         : maybe_layout_option->toLayout());
 
       at::optional<IValue> maybe_device_option = node->get(attr::device);
       if (!maybe_device_option)
         return {};
-      auto device = (maybe_device_option->isNone()
-          ? at::kCPU
-          : maybe_device_option->toDevice());
+      auto device =
+          (maybe_device_option->isNone() ? at::kCPU
+                                         : maybe_device_option->toDevice());
 
       at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
       if (!maybe_dtype_option)
         return {};
-      auto dtype = (maybe_dtype_option->isNone()
-          ? at::kFloat
-          : maybe_dtype_option->toScalarType());
+      auto dtype =
+          (maybe_dtype_option->isNone() ? at::kFloat
+                                        : maybe_dtype_option->toScalarType());
 
-      return {DimensionedTensorType::create(
-          dtype, device, dim)};
+      return {DimensionedTensorType::create(dtype, device, dim)};
     };
 
     // Requirements:

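For intuition, here is a toy Python sketch (not the C++ pass itself; the type encoding is made up for illustration) of what the containsTensorType / DoesntRefineOutputs check above decides: shape propagation is skipped for a node whose output types contain no Tensor anywhere, e.g. float / float.

def contains_tensor_type(t):
    # Toy type encoding: "Tensor", another scalar name such as "float",
    # or a tuple/list of contained types (standing in for containedTypes()).
    if isinstance(t, (tuple, list)):
        return any(contains_tensor_type(c) for c in t)
    return t == "Tensor"

def doesnt_refine_outputs(output_types):
    # True when no output contains a Tensor, so there is nothing to refine.
    return not any(contains_tensor_type(t) for t in output_types)

assert doesnt_refine_outputs(["float", "int"])            # skip propagation
assert not doesnt_refine_outputs([("Tensor", "Tensor")])  # Tuple of Tensors: propagate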
torch/csrc/jit/symbolic_script.cpp

Lines changed: 23 additions & 4 deletions
@@ -691,15 +691,34 @@ const std::vector<std::string> functions = {
 
             return output, backward
 
+        def AD_fused_dropout_backward(grad,
+                                      mask,
+                                      p1m: float):
+            p1r = 1. / p1m
+            if grad.requires_grad:
+                grad_input = grad * (mask.type_as(grad) * p1r)
+            else:
+                grad_input = torch._masked_scale(grad, mask, p1r)
+            return grad_input
+
         def dropout(input,
                     p: float,
                     train: bool):
-            mask = torch.empty_like(input)
-            mask.bernoulli_(1 - p)
-            res = mask * input / (1.0 - p)
+            use_cuda = input.is_cuda
+            # CUDA has a fused dropout implementation
+            p1m = 1. - p
+            if use_cuda:
+                res, mask = torch._fused_dropout(input, p1m)
+            else:
+                mask = torch.empty_like(input)
+                mask.bernoulli_(p1m)
+                res = mask * input / p1m
 
             def backward(grad_output):
-                grad_input = grad_output * mask / (1.0 - p)
+                if use_cuda:
+                    grad_input = AD_fused_dropout_backward(grad_output, mask, p1m)
+                else:
+                    grad_input = grad_output * mask / p1m
                 return grad_input, None, None
             return res, backward

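As a quick numerical sanity check of the two backward branches above (a sketch that assumes a CUDA device and that the private ops torch._fused_dropout and torch._masked_scale keep the signatures used in this diff), both should agree with the plain unfused formula grad * mask / (1 - p):

import torch

if torch.cuda.is_available():
    p = 0.3
    p1m = 1.0 - p                                   # keep probability
    p1r = 1.0 / p1m
    x = torch.randn(8, 8, device="cuda")
    grad = torch.randn_like(x)

    _, mask = torch._fused_dropout(x, p1m)          # mask from the fused forward
    fused = torch._masked_scale(grad, mask, p1r)    # branch taken when grad does not require grad
    manual = grad * (mask.type_as(grad) * p1r)      # branch taken when grad requires grad
    reference = grad * mask.type_as(grad) / p1m     # unfused CPU-style formula

    assert torch.allclose(fused, manual)
    assert torch.allclose(manual, reference)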