key/value shape mismatch in awq/modules/fused/cache.py #716

Open

DominikHil opened this issue Feb 20, 2025 · 0 comments

DominikHil commented Feb 20, 2025
Hello,

First, I'd like to give a big "thank you" to the authors for creating and maintaining this awesome repo!

I'm trying to run some benchmarks using AWQ. However, I'm encountering the following error:

[truncated]
[...]/python3.10/site-packages/awq/modules/fused/cache.py", line 61, in roll_kv_n_steps
    self.k[:, :, :, -n:, :] = 0
IndexError: too many indices for tensor of dimension 4

Looking into the mentioned file, I noticed that keys and values are initialized with the same shape, but the roll_kv_n_steps (line 48) and decrease_batch_size (line 76) functions treat the keys as having an additional dimension.

While I'm not the most familiar with KV caching, it is surprising that there seems to be a shape mismatch, especially since update_kv (line 41) treats keys and values as having matching shapes.

What is the extra dimension of the keys used for?

Going further, the last commit that changed this file (dfe396a#diff-fa6fa842a1ca1adc4cac01e50a41f72d7ee126c75bf1e872e454584c3c9c62c7) seems to have changed the shape definition of the keys and values (previously the keys were initialized with 5 dimensions, which was changed to 4); update_kv was adapted to this change in key shape, but roll_kv_n_steps and decrease_batch_size were not.
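For illustration, here is a minimal sketch of the mismatch (the shape (batch, n_kv_heads, max_seq_len, head_dim) and the sizes are my assumptions about the cache layout; only the number of dimensions matters):

import torch

k = torch.zeros(1, 8, 4096, 128)  # 4-D, like the key cache after commit dfe396a

try:
    k[:, :, :, -100:, :] = 0  # five indices, as in roll_kv_n_steps
except IndexError as e:
    print(e)  # "too many indices for tensor of dimension 4"

k[:, :, -100:, :] = 0  # four indices work, matching how the values are indexed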

I tried to create a minimal reproducible example (see attachment) in which Llama 3 8B is quantized with a group size of 256 and layer fusing:

1. Setup (Linux):
conda create -n awq python=3.10
conda activate awq
conda install cuda -c nvidia
pip3 install torch torchvision torchaudio
pip install autoawq[kernels]
2. Quantize:
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = ""  # change to the dir where you want to save the quantized model

# quantization takes ~15 min
model = AutoAWQForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B",
                                           safetensors=True,
                                           torch_dtype="float16",
                                           device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

quant_config = {"w_bit": 4,
                "q_group_size": 256,
                "zero_point": 1,
                "version": "GEMM"}

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
3. Benchmark:
import datasets
import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, AwqConfig
from transformers.utils.quantization_config import AWQLinearVersion

quant_path = ""  # change to the dir where you saved the quantized model

model = AutoAWQForCausalLM.from_quantized(quant_path,
                                          safetensors=True,
                                          torch_dtype=torch.float16,
                                          device_map="auto",
                                          quantization_config=AwqConfig(version=AWQLinearVersion.GEMM,
                                                                        bits=4,
                                                                        group_size=256,
                                                                        do_fuse=True,
                                                                        fuse_max_seq_len=4096)
                                          )
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

ds = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

# ppl measure based on: https://huggingface.co/docs/transformers/perplexity
encodings = tokenizer("\n\n".join(ds["text"]), return_tensors="pt")
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
times = []
stride = 512
max_length = 1024

for begin_loc in range(0, seq_len, stride):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.model.device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

print(f"Wikitext word ppl: {torch.exp(torch.stack(nlls).mean())}")

Running the benchmark (step 3) with the original code throws the above error.

However, when I change the mentioned line in "awq/modules/fused/cache.py" so that the key shape matches the value shape:

line 59:
from

self.k[:, :, :, -n:, :] = 0

to

self.k[:, :, -n:, :] = 0

then the code from step 3 runs with output:
Wikitext word ppl: 6.297884941101074

It would be awesome if someone could take a look at this.

If there is indeed an issue with the above line, then the following lines could also be considered for change:

line 55:
from

self.k = torch.roll(self.k, shifts=-n, dims=3)

to

self.k = torch.roll(self.k, shifts=-n, dims=2)

line 79:
from

self.k = self.k[:to_bsz, :, :, :, :]

to

self.k = self.k[:to_bsz, :, :, :]
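As a sanity check that the three proposed changes are mutually consistent, here is a minimal sketch of a roll plus batch shrink on a 4-D tensor (the shape (batch, n_kv_heads, max_seq_len, head_dim) and all sizes are my assumptions; the operations mirror the proposed lines 55, 59, and 79):

import torch

n, to_bsz = 100, 1
k = torch.randn(2, 8, 4096, 128)  # assumed (batch, n_kv_heads, max_seq_len, head_dim)

k = torch.roll(k, shifts=-n, dims=2)  # proposed line 55: roll along the sequence axis
k[:, :, -n:, :] = 0                   # proposed line 59: zero the rolled-in slots
k = k[:to_bsz, :, :, :]               # proposed line 79: shrink the batch dimension

assert k.shape == (to_bsz, 8, 4096, 128)
assert torch.all(k[:, :, -n:, :] == 0)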