-
Notifications
You must be signed in to change notification settings - Fork 215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feat]: Add support for kleidiai quantization schemes #1447
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1447
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Hi @ng-05! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
Hello @jerryzh168 ,
How should we take this input from user regarding quantization schemes. Groupsize parameter can not server the purpose as channelsize will change for diff matmuls in a model? Currently I am using "scheme" parameter to differentiate between the two. |
yeah, you can use https://github.com/pytorch/ao/blob/main/torchao/quantization/granularity.py: PerGroup and PerAxis(axis=0) (assuming channel dimension is 0), examples: ao/torchao/quantization/quant_api.py Line 1069 in 070345d
ao/tutorials/calibration_flow/static_quant.py Line 168 in 070345d
|
Thanks for the inputs @jerryzh168. I have initial change ready which extends int4_weight_only quantizer. The 4 bit KleidiAI kernels quantizes the weight in torchao and input to 8 bit within the kernel itself instead of quantizing the input in the torchao the way int8_dynamic_activation_int4_weight does. Currently neither int4_weight_only nor int8_dynamic_activation_int4_weight fully aligns with the way kelidiai 4 bit kernels are working. I feel int4_weight_only is closest to what we want to do, what are your thoughts on this? |
yeah int4_weight_only means no input quantization, I think it aligns better with we also have ao/torchao/experimental/quant_api.py Line 485 in 4738377
You can also check out: #995 |
738d7f2
to
358d6b4
Compare
Hello @jerryzh168 , I am planning to migrate Can you please review this change, specially the change the in I am also not sure if int8_dynamic_activation_intx_weight* quantizer can be accessed by torchchat currently? Do you have an example how torchchat can pass args like granularity, mapping_type from torchchat cli to torchao ? |
target: Target | ||
|
||
# Allow bias access via layout | ||
bias: Optional[torch.Tensor] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
layout is more of a "type" actually, why is bias Tensor passed here?
the corresponding "storage" is TensorImpl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to access the bias to be packed with my weights and scales. I can not find any other existing way to pass bias to from_plain()
api via
ao/torchao/dtypes/affine_quantized_tensor.py
Line 281 in ad61822
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) |
How do you think I should access bias in the packing function here.
def _pack_weights_native( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ng-05 - bias
is not required to differentiate this layout i.e. you can dispatch to this layout with and without bias.
That said, @jerryzh168 - we do need to figure out how to get the bias to the from_plain
method. I know it doesn't play nice with the tensor
representation abstraction for AQT
, do you have any other suggestions?
Perhaps until then can we just do a add op followed by gemm, and put a TODO on fixing APIs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it does not fit into AQT
, I think it's fine to create a new tensor subclass, but putting bias
Tensor in the layout is bit conflicting the design (has_bias boolean is fine) since it's a "type", should not store data there
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
layout is following the design of https://pytorch.org/docs/stable/tensor_attributes.html#torch.layout
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good to me overall, can you add some tests?
I don't think we need to expose these fine grained args to torchchat cli, we just need these high level args like: https://github.com/pytorch/torchchat/blob/main/torchchat/quant_config/mobile.json we are also working on migrating torchchat to use torchao quant api btw |
torchchat does not currently use int8_dynamic_activation_intx_weight, but instead a submodule swap API here: https://github.com/pytorch/ao/blob/main/torchao/experimental/quant_api.py#L438 We will be switching torchchat to use int8_dynamic_activation_intx_weight instead, but I first need to land some changes for perf/clarity: #1553 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand that this quant API now connects kernels we landed in aten with quant API. If the kernels you guys landed in aten are actually new ops, unlike int4pack_mm and friends, then why did we land them there in the first place. In order to reach those kernels you need ao dep anyway? (@digantdesai I know you tagged me on that PR but i never really deep dived into that so maybe you have context here)
Besides taht i have a couple of questions.
- In the current form it is only making aten op you guys added available via tensor subclass api, so what happens to say torch.compile (maybe this works?) or AOTI usecase?
- I would also like to see if we can leverage this op in executorch, for which integration into AO would have been a better choice compared to this being aten op
- If kleidi's op performs better than whats in this repo (and note that @digantdesai has actually integrated some of the kleidi ops that I guess you guys are aware of), then can we just use that op directly or have a path to kleidi's impl for the cpu ops that exist under experimental/ops?
Any specific reason why use subclass API instead of module swap? |
I am unaware of executorch status and what performance you get with klediai kernels over there. I tested this change with torch.compile() and it seems to be working fine. |
@jerryzh168 @kimishpatel are we testing the 4 bit symmetric quantization anywhere without adding a dequant layer on the result? In my testing I am seeing very poor accuracy with symmetric 4 bit quant scheme with this PR. |
target: Target | ||
|
||
# Allow bias access via layout | ||
bias: Optional[torch.Tensor] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ng-05 - bias
is not required to differentiate this layout i.e. you can dispatch to this layout with and without bias.
That said, @jerryzh168 - we do need to figure out how to get the bias to the from_plain
method. I know it doesn't play nice with the tensor
representation abstraction for AQT
, do you have any other suggestions?
Perhaps until then can we just do a add op followed by gemm, and put a TODO on fixing APIs?
My understanding from @jerryzh168 is that long-term, torchao plans to support pt2e and subclass/quantize_ based quantization long-term. I believe torchchat is working on (and has already partially completed) moving module-swap based quantization over to use quantize_ (cc @Jack-Khuu to keep me honest there). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think with extended AQT it is close. Left some comments.
): | ||
assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) | ||
assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" | ||
assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to assert bias is None when layout.target != ATEN?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if bias is none and target layout.target != ATEN then execution will never reach this point. It will be executed via the "native" target which has bias assert. Please check here: https://github.com/pytorch/ao/pull/1447/files#diff-3e4ffa192fe9921999dd6a798fc3fa620377896ef9ba65245b1e5ab8c0d7d344R593
"""Enum that indicates the backend target""" | ||
|
||
NATIVE = auto() | ||
ATEN = auto() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What makes it ATEN specific? Should it be Kleidi? I am thinking from longer term perspective where we will use this layout arg to differentiate how to pack weights.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I decided to stick to aten so that more ops from aten can be added in future if needed. This enables this layout to work with torchao ops and aten ops.
dyn_quant_matmul_4bit
is not supposed to be KleidiAI specific only
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But the aten implementation and the packing routine it selects is Kleidi specific.
register_aqt_quantized_linear_dispatch( | ||
_linear_check, | ||
_linear_impl, | ||
) | ||
|
||
|
||
class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious, (1) why adding bias as optional to_affine_quantized_intx
isn't an option? or alternatively, (2) for native kernels we will add bias support sooner rather than later, so should we use this for both native and aten?
We need to check for |
@digantdesai Please check This change was suggested by @jerryzh168 |
@pytorchbot label "arm priority" |
Didn't find following labels among repository labels: arm priority |
What about (2)? This is where I am trying to figure out how to make both native and aten have bias, and set it up for the future where we will, potentially delegate aten impl. |
I am ok with merging this for now, and since we will have to figure out details between native and aten for intx, we can resolve this later. Stamping to unblock you. |
@pytorchbot merge |
Merge failedReason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 5 jobs have failed, first few of them are: Code Analysis with Ruff, PR Label Check, Build Docs, Run Float8 Tests, Run Regression Tests Details for Dev Infra teamRaised by workflow job |
I added CI to run torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py so you don't have to do it locally (.github/workflows/torchao_experimental_test.yml). Just make sure the "Run TorchAO Experimental Tests" job finishes successfully before merging. I don't see any tests for your new quantization code added. For now, can you add a new test file (e.g., torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py) that tests your quantizer/kernel code. Later I can merge the existing test_packed_linear_int8_dynamic_activation_intx_weight_layout.py and the tests for aten target in the new file you create into one test. You can your test to CI by adding a line here: https://github.com/pytorch/ao/blob/main/.github/workflows/torchao_experimental_test.yml#L42 |
Good point about tests, @ng-05 can we please add them? Also in the CI now? |
Description: Allow int8_dynamic_activation_intx_weight to work with aten _dyn_quant_matmul_4bit op Needs : pytorch/pytorch#134124 or Pytorch > 2.6.0 Signed-off-by: Nikhil Gupta <[email protected]>
Signed-off-by: Nikhil Gupta <[email protected]>
Signed-off-by: Nikhil Gupta <[email protected]>
0f49f3d
to
952ab42
Compare
Signed-off-by: Nikhil Gupta <[email protected]>
952ab42
to
eb61a24
Compare
Done. |
Thank you very much. I have done the requested changes around tests. Please review |
Description:
Allow int8_dynamic_activation_intx_weight to work with aten _dyn_quant_matmul_4bit op
Needs : pytorch/pytorch#134124 or Pytorch > 2.6.0