Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat]: Add support for kleidiai quantization schemes #1447

Merged
merged 4 commits into from
Jan 30, 2025

Conversation

ng-05
Copy link
Contributor

@ng-05 ng-05 commented Dec 19, 2024

Description:
Allow int8_dynamic_activation_intx_weight to work with aten _dyn_quant_matmul_4bit op

Needs : pytorch/pytorch#134124 or Pytorch > 2.6.0

Copy link

pytorch-bot bot commented Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1447

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link

Hi @ng-05!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@ng-05 ng-05 marked this pull request as draft December 19, 2024 10:44
@ng-05
Copy link
Contributor Author

ng-05 commented Jan 8, 2025

Hello @jerryzh168 ,
We want to support two diff type of int4 schemes.

  1. symmetric_groupwise -> groupsize [ 32, 64, 128 etc ]
  2. symmetric_channelwise -> groupsize is equal to channelsize of the matmul weights

How should we take this input from user regarding quantization schemes. Groupsize parameter can not server the purpose as channelsize will change for diff matmuls in a model?

Currently I am using "scheme" parameter to differentiate between the two.
aarch64_cpu_channelwise.json
aarch64_cpu_groupwise.json

@jerryzh168
Copy link
Contributor

jerryzh168 commented Jan 8, 2025

How should we take this input from user regarding quantization schemes. Groupsize parameter can not server the purpose as channelsize will change for diff matmuls in a model?

yeah, you can use https://github.com/pytorch/ao/blob/main/torchao/quantization/granularity.py: PerGroup and PerAxis(axis=0) (assuming channel dimension is 0), examples:

granularity: Optional[
,
weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 9, 2025

How should we take this input from user regarding quantization schemes. Groupsize parameter can not server the purpose as channelsize will change for diff matmuls in a model?

yeah, you can use https://github.com/pytorch/ao/blob/main/torchao/quantization/granularity.py: PerGroup and PerAxis(axis=0) (assuming channel dimension is 0), examples:

granularity: Optional[

,

weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)

Thanks for the inputs @jerryzh168.

I have initial change ready which extends int4_weight_only quantizer.

The 4 bit KleidiAI kernels quantizes the weight in torchao and input to 8 bit within the kernel itself instead of quantizing the input in the torchao the way int8_dynamic_activation_int4_weight does.
For this reason I am extending the int4_weight_only api. I am slightly confused if the intention of this api is to convey NO input quantisation to user?

Currently neither int4_weight_only nor int8_dynamic_activation_int4_weight fully aligns with the way kelidiai 4 bit kernels are working.

I feel int4_weight_only is closest to what we want to do, what are your thoughts on this?

@jerryzh168
Copy link
Contributor

jerryzh168 commented Jan 9, 2025

I feel int4_weight_only is closest to what we want to do, what are your thoughts on this?

yeah int4_weight_only means no input quantization, I think it aligns better with int8_dynamic_activation_int4_weight, you can use a different layout and customize the logic for input quantization.

we also have

def int8_dynamic_activation_intx_weight(
that is the same as your use case. there is some ongoing refactors/updates there as well right now

You can also check out: #995

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 11, 2025

Hello @jerryzh168 , I am planning to migrate int8_dynamic_activation_intx_weight api to int8_dynamic_activation_intx_weight_v2.
For now I have kept the API separate for review and testing.

Can you please review this change, specially the change the in _get_linear_subclass_inserter which allow bias propagation. The bias needed by torch.ops.aten._dyn_quant_pack_4bit_weight.

I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

target: Target

# Allow bias access via layout
bias: Optional[torch.Tensor] = None
Copy link
Contributor

@jerryzh168 jerryzh168 Jan 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

layout is more of a "type" actually, why is bias Tensor passed here?

the corresponding "storage" is TensorImpl

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to access the bias to be packed with my weights and scales. I can not find any other existing way to pass bias to from_plain() api via

tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)

How do you think I should access bias in the packing function here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ng-05 - bias is not required to differentiate this layout i.e. you can dispatch to this layout with and without bias.

That said, @jerryzh168 - we do need to figure out how to get the bias to the from_plain method. I know it doesn't play nice with the tensor representation abstraction for AQT, do you have any other suggestions?

Perhaps until then can we just do a add op followed by gemm, and put a TODO on fixing APIs?

Copy link
Contributor

@jerryzh168 jerryzh168 Jan 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it does not fit into AQT, I think it's fine to create a new tensor subclass, but putting bias Tensor in the layout is bit conflicting the design (has_bias boolean is fine) since it's a "type", should not store data there

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@jerryzh168 jerryzh168 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me overall, can you add some tests?

@jerryzh168
Copy link
Contributor

I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

I don't think we need to expose these fine grained args to torchchat cli, we just need these high level args like: https://github.com/pytorch/torchchat/blob/main/torchchat/quant_config/mobile.json

we are also working on migrating torchchat to use torchao quant api btw

@metascroy
Copy link
Contributor

Hello @jerryzh168 , I am planning to migrate int8_dynamic_activation_intx_weight api to int8_dynamic_activation_intx_weight_v2. For now I have kept the API separate for review and testing.

Can you please review this change, specially the change the in _get_linear_subclass_inserter which allow bias propagation. The bias needed by torch.ops.aten._dyn_quant_pack_4bit_weight.

I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

torchchat does not currently use int8_dynamic_activation_intx_weight, but instead a submodule swap API here: https://github.com/pytorch/ao/blob/main/torchao/experimental/quant_api.py#L438

We will be switching torchchat to use int8_dynamic_activation_intx_weight instead, but I first need to land some changes for perf/clarity: #1553

Copy link
Contributor

@kimishpatel kimishpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that this quant API now connects kernels we landed in aten with quant API. If the kernels you guys landed in aten are actually new ops, unlike int4pack_mm and friends, then why did we land them there in the first place. In order to reach those kernels you need ao dep anyway? (@digantdesai I know you tagged me on that PR but i never really deep dived into that so maybe you have context here)

Besides taht i have a couple of questions.

  • In the current form it is only making aten op you guys added available via tensor subclass api, so what happens to say torch.compile (maybe this works?) or AOTI usecase?
  • I would also like to see if we can leverage this op in executorch, for which integration into AO would have been a better choice compared to this being aten op
  • If kleidi's op performs better than whats in this repo (and note that @digantdesai has actually integrated some of the kleidi ops that I guess you guys are aware of), then can we just use that op directly or have a path to kleidi's impl for the cpu ops that exist under experimental/ops?

@kimishpatel
Copy link
Contributor

Hello @jerryzh168 , I am planning to migrate int8_dynamic_activation_intx_weight api to int8_dynamic_activation_intx_weight_v2. For now I have kept the API separate for review and testing.
Can you please review this change, specially the change the in _get_linear_subclass_inserter which allow bias propagation. The bias needed by torch.ops.aten._dyn_quant_pack_4bit_weight.
I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

torchchat does not currently use int8_dynamic_activation_intx_weight, but instead a submodule swap API here: main/torchao/experimental/quant_api.py#L438

We will be switching torchchat to use int8_dynamic_activation_intx_weight instead, but I first need to land some changes for perf/clarity: #1553

Any specific reason why use subclass API instead of module swap?

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 13, 2025

I understand that this quant API now connects kernels we landed in aten with quant API. If the kernels you guys landed in aten are actually new ops, unlike int4pack_mm and friends, then why did we land them there in the first place. In order to reach those kernels you need ao dep anyway? (@digantdesai I know you tagged me on that PR but i never really deep dived into that so maybe you have context here)

Besides taht i have a couple of questions.

  • In the current form it is only making aten op you guys added available via tensor subclass api, so what happens to say torch.compile (maybe this works?) or AOTI usecase?
  • I would also like to see if we can leverage this op in executorch, for which integration into AO would have been a better choice compared to this being aten op
  • If kleidi's op performs better than whats in this repo (and note that @digantdesai has actually integrated some of the kleidi ops that I guess you guys are aware of), then can we just use that op directly or have a path to kleidi's impl for the cpu ops that exist under experimental/ops?

I am unaware of executorch status and what performance you get with klediai kernels over there. I tested this change with torch.compile() and it seems to be working fine.

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 13, 2025

@jerryzh168 @kimishpatel are we testing the 4 bit symmetric quantization anywhere without adding a dequant layer on the result? In my testing I am seeing very poor accuracy with symmetric 4 bit quant scheme with this PR.
For comparison, the mean relative error jumps from 0.0006 (with llama.cpp algo ) to 0.0044 (torchao algo)with kleidiai kernels.
This is the reference scheme that I am using for the 4 bit symmetric quant. ggerganov/llama.cpp#729

target: Target

# Allow bias access via layout
bias: Optional[torch.Tensor] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ng-05 - bias is not required to differentiate this layout i.e. you can dispatch to this layout with and without bias.

That said, @jerryzh168 - we do need to figure out how to get the bias to the from_plain method. I know it doesn't play nice with the tensor representation abstraction for AQT, do you have any other suggestions?

Perhaps until then can we just do a add op followed by gemm, and put a TODO on fixing APIs?

torchao/experimental/quant_api.py Outdated Show resolved Hide resolved
@metascroy
Copy link
Contributor

Hello @jerryzh168 , I am planning to migrate int8_dynamic_activation_intx_weight api to int8_dynamic_activation_intx_weight_v2. For now I have kept the API separate for review and testing.
Can you please review this change, specially the change the in _get_linear_subclass_inserter which allow bias propagation. The bias needed by torch.ops.aten._dyn_quant_pack_4bit_weight.
I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ?

torchchat does not currently use int8_dynamic_activation_intx_weight, but instead a submodule swap API here: main/torchao/experimental/quant_api.py#L438
We will be switching torchchat to use int8_dynamic_activation_intx_weight instead, but I first need to land some changes for perf/clarity: #1553

Any specific reason why use subclass API instead of module swap?

My understanding from @jerryzh168 is that long-term, torchao plans to support pt2e and subclass/quantize_ based quantization long-term. I believe torchchat is working on (and has already partially completed) moving module-swap based quantization over to use quantize_ (cc @Jack-Khuu to keep me honest there).

Copy link
Contributor

@digantdesai digantdesai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with extended AQT it is close. Left some comments.

):
assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to assert bias is None when layout.target != ATEN?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if bias is none and target layout.target != ATEN then execution will never reach this point. It will be executed via the "native" target which has bias assert. Please check here: https://github.com/pytorch/ao/pull/1447/files#diff-3e4ffa192fe9921999dd6a798fc3fa620377896ef9ba65245b1e5ab8c0d7d344R593

"""Enum that indicates the backend target"""

NATIVE = auto()
ATEN = auto()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What makes it ATEN specific? Should it be Kleidi? I am thinking from longer term perspective where we will use this layout arg to differentiate how to pack weights.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to stick to aten so that more ops from aten can be added in future if needed. This enables this layout to work with torchao ops and aten ops.
dyn_quant_matmul_4bit is not supposed to be KleidiAI specific only

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the aten implementation and the packing routine it selects is Kleidi specific.

register_aqt_quantized_linear_dispatch(
_linear_check,
_linear_impl,
)


class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious, (1) why adding bias as optional to_affine_quantized_intx isn't an option? or alternatively, (2) for native kernels we will add bias support sooner rather than later, so should we use this for both native and aten?

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 23, 2025

I can not run this test as the op build is broken for aarch64 linux.

Can you remind me of this issue? how can we resolve this?

We need to check for aarch64 in cmake processor along with arm64. Even after that I am seeing few build issues. I did not have a deeper look at them.

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 23, 2025

@digantdesai
for this comment :
Curious, (1) why adding bias as optional to_affine_quantized_intx isn't an option? or alternatively, (2) for native kernels we will add bias support sooner rather than later, so should we use this for both native and aten?

Please check
#1447 (comment) and #1447 (comment)

This change was suggested by @jerryzh168

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 24, 2025

@pytorchbot label "arm priority"

Copy link

pytorch-bot bot commented Jan 24, 2025

Didn't find following labels among repository labels: arm priority

@digantdesai
Copy link
Contributor

digantdesai commented Jan 27, 2025

@digantdesai for this comment : Curious, (1) why adding bias as optional to_affine_quantized_intx isn't an option? or alternatively, (2) for native kernels we will add bias support sooner rather than later, so should we use this for both native and aten?

Please check #1447 (comment) and #1447 (comment)

This change was suggested by @jerryzh168

What about (2)? This is where I am trying to figure out how to make both native and aten have bias, and set it up for the future where we will, potentially delegate aten impl.

@digantdesai
Copy link
Contributor

I am ok with merging this for now, and since we will have to figure out details between native and aten for intx, we can resolve this later.

Stamping to unblock you.

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 28, 2025

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: superuser

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 28, 2025
@facebook-github-bot
Copy link

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 28, 2025

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 5 jobs have failed, first few of them are: Code Analysis with Ruff, PR Label Check, Build Docs, Run Float8 Tests, Run Regression Tests

Details for Dev Infra team Raised by workflow job

@metascroy
Copy link
Contributor

Overall I think it's close. Can you be sure to run torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py to make sure the tests pass. It is not currently enabled in CI.
You should also add some test cases for your new target to that file to check accuracy/exportability.
I'll be out for the week of January 20, but @digantdesai should be able to help with review during that time.

I can not run this test as the op build is broken for aarch64 linux.

I added CI to run torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py so you don't have to do it locally (.github/workflows/torchao_experimental_test.yml). Just make sure the "Run TorchAO Experimental Tests" job finishes successfully before merging.

I don't see any tests for your new quantization code added. For now, can you add a new test file (e.g., torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py) that tests your quantizer/kernel code. Later I can merge the existing test_packed_linear_int8_dynamic_activation_intx_weight_layout.py and the tests for aten target in the new file you create into one test.

You can your test to CI by adding a line here: https://github.com/pytorch/ao/blob/main/.github/workflows/torchao_experimental_test.yml#L42

@digantdesai
Copy link
Contributor

Good point about tests, @ng-05 can we please add them? Also in the CI now?

ng-05 added 3 commits January 30, 2025 14:08
Description:

Allow int8_dynamic_activation_intx_weight to work with aten  _dyn_quant_matmul_4bit op

Needs : pytorch/pytorch#134124 or Pytorch > 2.6.0

Signed-off-by: Nikhil Gupta <[email protected]>
@ng-05
Copy link
Contributor Author

ng-05 commented Jan 30, 2025

Good point about tests, @ng-05 can we please add them? Also in the CI now?

Done.

@ng-05
Copy link
Contributor Author

ng-05 commented Jan 30, 2025

Overall I think it's close. Can you be sure to run torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py to make sure the tests pass. It is not currently enabled in CI.
You should also add some test cases for your new target to that file to check accuracy/exportability.
I'll be out for the week of January 20, but @digantdesai should be able to help with review during that time.

I can not run this test as the op build is broken for aarch64 linux.

I added CI to run torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py so you don't have to do it locally (.github/workflows/torchao_experimental_test.yml). Just make sure the "Run TorchAO Experimental Tests" job finishes successfully before merging.

I don't see any tests for your new quantization code added. For now, can you add a new test file (e.g., torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py) that tests your quantizer/kernel code. Later I can merge the existing test_packed_linear_int8_dynamic_activation_intx_weight_layout.py and the tests for aten target in the new file you create into one test.

You can your test to CI by adding a line here: https://github.com/pytorch/ao/blob/main/.github/workflows/torchao_experimental_test.yml#L42

Thank you very much. I have done the requested changes around tests. Please review

@digantdesai digantdesai merged commit 7815262 into pytorch:main Jan 30, 2025
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants