Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gptq dp #248

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llmc/compression/quantization/gptq.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import copy
import functools
import math
import os
from abc import ABCMeta, abstractmethod
from collections import defaultdict

import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
from loguru import logger
Expand Down Expand Up @@ -249,6 +251,7 @@ def cache_input_hook(self, m, inp, out, name, feat_dict):

@torch.no_grad()
def add_batch(self, layer, name, inp, out):
world_size = int(os.environ['WORLD_SIZE'])
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
Expand Down Expand Up @@ -279,6 +282,11 @@ def add_batch(self, layer, name, inp, out):
inp = math.sqrt(2 / self.layers_cache[name]['nsamples']) * inp.float()
self.layers_cache[name]['H'] += inp.matmul(inp.t())

dist.all_reduce(self.layers_cache[name]['H'], op=dist.ReduceOp.SUM)
dist.all_reduce(torch.tensor(self.layers_cache[name]['nsamples']).cuda(),
op=dist.ReduceOp.SUM)
self.layers_cache[name]['H'] /= world_size

@torch.no_grad()
def layer_init(self, layer, name):
W = layer.weight.data.clone()
Expand Down
Loading