
Commit c7d6144

aparna-aketi authored and facebook-github-bot committed
Modifying DPLossFastGradientClipping to add support for generative tasks with ghost clipping (#722)
Summary: Pull Request resolved: #722

Generative NLP tasks output predictions of shape (B, T, C), i.e., (batch_size, sequence_length, vocab_size). To compute the cross-entropy loss, the predictions are usually reshaped to (BxT, C) and the targets to (BxT). This creates an issue for Ghost Clipping's per-sample loss computation, because BxT is treated as the batch size. In particular, the current implementation of Ghost Clipping produces loss_per_sample and coeff variables of shapes BxT and B, respectively, which causes a shape mismatch error.

This diff fixes the error by collapsing loss_per_sample to shape B, i.e., the loss is averaged or summed across the sequence_length dimension.

Reviewed By: EnayatUllah
Differential Revision: D68047256
fbshipit-source-id: ad7614e2cdba59869d762d810a14b96b465ee513
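To make the shape mismatch concrete, here is a small self-contained sketch (illustrative only, not part of the commit) of how a per-token cross-entropy loss for a generative batch is collapsed back to one value per sequence; the tensor names and sizes are arbitrary.

import torch
import torch.nn.functional as F

B, T, C = 4, 16, 32000  # batch_size, sequence_length, vocab_size (illustrative)
logits = torch.randn(B, T, C)
targets = torch.randint(0, C, (B, T))

# Standard practice for generative NLP: flatten before computing cross-entropy.
per_token_loss = F.cross_entropy(
    logits.view(B * T, C), targets.view(B * T), reduction="none"
)  # shape (BxT,) -- ghost clipping would mistake BxT for the batch size

# The fix: fold the sequence dimension back in and reduce over it, so the
# per-sample loss has shape (B,), matching the B clipping coefficients.
per_sample_loss = per_token_loss.view(B, T).mean(dim=1)  # or .sum(dim=1)
assert per_sample_loss.shape == (B,)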
1 parent 9741fe2 commit c7d6144

File tree: 1 file changed (+20, -6 lines)


opacus/utils/fast_gradient_clipping_utils.py

Lines changed: 20 additions & 6 deletions
@@ -68,7 +68,9 @@ def backward(self):
         reduced_loss.backward(retain_graph=True)
         self.optimizer.zero_grad()
         coeff = self.module.get_clipping_coef()
-        second_loss_per_sample = coeff * self.loss_per_sample
+        second_loss_per_sample = (
+            coeff.to(self.loss_per_sample.device) * self.loss_per_sample
+        )
         second_loss = torch.sum(second_loss_per_sample)
         self.module.disable_hooks()
         second_loss.backward()
@@ -104,15 +106,27 @@ def __init__(
         self.loss_reduction = loss_reduction
         self.criterion.reduction = "none"

-    def __call__(self, input, target) -> DPTensorFastGradientClipping:
+    def __call__(self, input, target, shape=None) -> DPTensorFastGradientClipping:
         """
         Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping
         """

-        loss_per_sample = self.criterion(
-            input,
-            target,
-        )
+        loss_per_sample = self.criterion(input, target)
+
+        if shape is not None and loss_per_sample.shape[0] == shape[0] * shape[1]:
+            # Note that the privacy unit for generative NLP tasks is per sequence.
+            # The shape variable is the shape of the logits before flattening, i.e., [batch_size, sequence_length, vocab_size].
+            # This variable is necessary for ghost clipping to work with generative NLP tasks.
+            loss_per_sample = loss_per_sample.view(shape[0], shape[1])  # BxT
+            if self.loss_reduction == "mean":
+                loss_per_sample = loss_per_sample.mean(dim=1)  # B
+            elif self.loss_reduction == "sum":
+                loss_per_sample = loss_per_sample.sum(dim=1)  # B
+            else:
+                raise ValueError(
+                    f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
+                )
+
         return DPTensorFastGradientClipping(
             self.module, self.optimizer, loss_per_sample, self.loss_reduction
         )
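With this change, a call site for a generative task can pass the pre-flattening logits shape so that ghost clipping recovers the true batch size. The following is a minimal usage sketch, not taken from the commit: model, input_ids, and labels are illustrative placeholders, and criterion is assumed to be the DPLossFastGradientClipping wrapper produced by the ghost-clipping setup; only the shape= keyword comes from this diff.

# Illustrative placeholders: model, input_ids, and labels stand in for an
# existing generative fine-tuning setup; criterion is assumed to be the
# DPLossFastGradientClipping wrapper from the ghost-clipping setup.
logits = model(input_ids)        # (B, T, C), before flattening
B, T, C = logits.shape

loss = criterion(
    logits.view(B * T, C),       # flattened predictions, as usual
    labels.view(B * T),          # flattened targets
    shape=(B, T, C),             # new argument: original logits shape
)
loss.backward()                  # per-sample clipping now sees B sequences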
