Support gptq dp (#248)
Co-authored-by: gushiqiao <[email protected]>
gushiqiao and gushiqiao authored Dec 9, 2024
1 parent 2f3e425 commit d2ab632
8 changes: 8 additions & 0 deletions llmc/compression/quantization/gptq.py
@@ -1,10 +1,12 @@
 import copy
 import functools
 import math
+import os
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import transformers
 from loguru import logger
@@ -249,6 +251,7 @@ def cache_input_hook(self, m, inp, out, name, feat_dict):
 
     @torch.no_grad()
     def add_batch(self, layer, name, inp, out):
+        world_size = int(os.environ['WORLD_SIZE'])
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
@@ -279,6 +282,11 @@ def add_batch(self, layer, name, inp, out):
         inp = math.sqrt(2 / self.layers_cache[name]['nsamples']) * inp.float()
         self.layers_cache[name]['H'] += inp.matmul(inp.t())
 
+        dist.all_reduce(self.layers_cache[name]['H'], op=dist.ReduceOp.SUM)
+        dist.all_reduce(torch.tensor(self.layers_cache[name]['nsamples']).cuda(),
+                        op=dist.ReduceOp.SUM)
+        self.layers_cache[name]['H'] /= world_size
+
     @torch.no_grad()
     def layer_init(self, layer, name):
         W = layer.weight.data.clone()
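Net effect of the patch: when GPTQ calibration runs data-parallel, each rank accumulates a Hessian H from its own shard of calibration samples, then the ranks all-reduce and average H so every rank quantizes against the same statistics. The snippet below is a minimal sketch of that averaging step, assuming torch.distributed is already initialized (e.g. via torchrun, which sets WORLD_SIZE); the names sync_hessian and local_hessian are illustrative and not part of llmc.

    import os

    import torch
    import torch.distributed as dist


    def sync_hessian(local_hessian: torch.Tensor) -> torch.Tensor:
        """Average a per-rank GPTQ Hessian across data-parallel ranks (sketch)."""
        world_size = int(os.environ.get('WORLD_SIZE', '1'))
        if world_size == 1 or not dist.is_initialized():
            return local_hessian  # single-process run: nothing to synchronize
        # Sum the per-rank Hessians in place ...
        dist.all_reduce(local_hessian, op=dist.ReduceOp.SUM)
        # ... then divide by the rank count, mirroring the `H /= world_size`
        # step in the patch above.
        local_hessian /= world_size
        return local_hessian

Since each rank already normalizes its local H by its own sample count, summing and dividing by world_size gives (for equal shards) the same per-sample-averaged Hessian a single-GPU calibration would compute, so the downstream damping and Cholesky steps are unchanged.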
