diff --git a/llmc/compression/quantization/gptq.py b/llmc/compression/quantization/gptq.py
index f8fb5c1b..18f6beda 100644
--- a/llmc/compression/quantization/gptq.py
+++ b/llmc/compression/quantization/gptq.py
@@ -1,10 +1,12 @@
 import copy
 import functools
 import math
+import os
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import transformers
 from loguru import logger
@@ -249,6 +251,7 @@ def cache_input_hook(self, m, inp, out, name, feat_dict):
 
     @torch.no_grad()
     def add_batch(self, layer, name, inp, out):
+        world_size = int(os.environ['WORLD_SIZE'])
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
@@ -279,6 +282,11 @@ def add_batch(self, layer, name, inp, out):
         inp = math.sqrt(2 / self.layers_cache[name]['nsamples']) * inp.float()
         self.layers_cache[name]['H'] += inp.matmul(inp.t())
 
+        dist.all_reduce(self.layers_cache[name]['H'], op=dist.ReduceOp.SUM)
+        dist.all_reduce(torch.tensor(self.layers_cache[name]['nsamples']).cuda(),
+                        op=dist.ReduceOp.SUM)
+        self.layers_cache[name]['H'] /= world_size
+
     @torch.no_grad()
     def layer_init(self, layer, name):
         W = layer.weight.data.clone()
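
The patch sums the per-rank Hessian accumulators with `dist.all_reduce` and divides by the world size, so every data-parallel rank ends up with the Hessian averaged over all calibration shards. The snippet below is a minimal, self-contained sketch of that pattern, not the llmc implementation: it assumes `torch.distributed` is already initialized (e.g. by `torchrun`), uses hypothetical helper names `local_hessian` and `sync_hessian`, and for clarity performs the synchronization once on a fully accumulated Hessian rather than inside every `add_batch` call.

```python
import math
import os

import torch
import torch.distributed as dist


def local_hessian(inputs: torch.Tensor) -> torch.Tensor:
    # Build this rank's H = (2 / nsamples) * X^T X from its calibration shard.
    # inputs: (nsamples, in_features)
    x = inputs.float().t()                        # (in_features, nsamples)
    x = math.sqrt(2.0 / inputs.shape[0]) * x      # fold the 2 / nsamples factor into X
    return x.matmul(x.t())                        # (in_features, in_features)


def sync_hessian(h: torch.Tensor) -> torch.Tensor:
    # Sum the partial Hessians across ranks, then divide by the world size so the
    # result keeps the same scale as a single-rank run (as the patch does for 'H').
    # With the NCCL backend, h must live on this rank's GPU.
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    if world_size > 1 and dist.is_initialized():
        dist.all_reduce(h, op=dist.ReduceOp.SUM)  # in-place sum over all ranks
        h /= world_size
    return h
```

Under `torchrun --nproc_per_node=N`, each rank would compute `local_hessian` from its own slice of the calibration set and call `sync_hessian` once per layer before the GPTQ weight update.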