Skip to content

Commit d2ab632

Browse files
gushiqiao and gushiqiao (co-author) authored
Support gptq dp (#248)
Co-authored-by: gushiqiao <[email protected]>
1 parent 2f3e425 commit d2ab632

File tree

1 file changed

+8
-0
lines changed
  • llmc/compression/quantization

1 file changed

+8
-0
lines changed

llmc/compression/quantization/gptq.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
 import copy
 import functools
 import math
+import os
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import transformers
 from loguru import logger
@@ -249,6 +251,7 @@ def cache_input_hook(self, m, inp, out, name, feat_dict):

     @torch.no_grad()
     def add_batch(self, layer, name, inp, out):
+        world_size = int(os.environ['WORLD_SIZE'])
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
@@ -279,6 +282,11 @@ def add_batch(self, layer, name, inp, out):
             inp = math.sqrt(2 / self.layers_cache[name]['nsamples']) * inp.float()
         self.layers_cache[name]['H'] += inp.matmul(inp.t())

+        dist.all_reduce(self.layers_cache[name]['H'], op=dist.ReduceOp.SUM)
+        dist.all_reduce(torch.tensor(self.layers_cache[name]['nsamples']).cuda(),
+                        op=dist.ReduceOp.SUM)
+        self.layers_cache[name]['H'] /= world_size
+
     @torch.no_grad()
     def layer_init(self, layer, name):
         W = layer.weight.data.clone()

0 commit comments

Comments
 (0)