[CPU] add Float8OpaqueTensor for dynamic float8 act float8 weight #3075
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3075
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures as of commit 7980de8 with merge base 838dceb.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
CC @mingfeima for review. Thanks.
Hi @mingfeima @jerryzh168 @andrewor14 Could you please review this PR? Thanks.
test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py
Hi @mingfeima @jerryzh168 @andrewor14 Though this PR depends on #3100, could you please review this PR? Thanks.
    float8_dtype=torch.float8_e4m3fn,
    block_size=block_size,
)
data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
Do you need to use `_quantize_affine_float8_non_decomposed`? (ao/torchao/quantization/quant_primitives.py, line 2425 in c96f2dd)
Thanks. Since we are not using Inductor for fusion like PT2E, it should be OK here.
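For context, a minimal sketch of what this dynamic float8 quantization computes, assuming per-row granularity; the helper names are illustrative and do not match torchao's `_choose_scale_float8`/`_quantize_affine_float8` signatures:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def choose_scale_per_row(hp: torch.Tensor) -> torch.Tensor:
    # One scale per row: map each row's absolute maximum onto the fp8 range.
    amax = hp.abs().amax(dim=-1, keepdim=True).float()
    return amax.clamp(min=1e-12) / FP8_MAX

def quantize_to_float8(hp: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Divide by the scale, clamp to the representable range, cast to float8.
    return (hp.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

x = torch.randn(4, 64, dtype=torch.bfloat16)
scale = choose_scale_per_row(x)
x_fp8 = quantize_to_float8(x, scale)
x_dq = x_fp8.float() * scale  # dequantized approximation of x
```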
torchao/float8/inference.py
Outdated
    return processed_granularity

def _normalize_granularity_opaque_tensor(
Why can't this reuse the other `normalize_granularity_tensor`?
Thanks. Updated.
torchao/float8/types.py
Outdated
# Define FP8Granularity type alias to break circular import dependencies
FP8Granularity = Union["PerTensor", "PerRow"]
FP8GranularityCPU = Union["PerTensor", "PerRow", "PerGroup"]
I feel we can reuse and extend `FP8Granularity`, and assert that only a subset of the options is supported for GPU right now.
Thanks. Updated.
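A sketch of the reuse-and-extend direction (simplified: the real alias uses string forward references to break circular imports, and the check helper here is hypothetical):

```python
from typing import Union

from torchao.quantization.granularity import PerGroup, PerRow, PerTensor

# One shared alias; device support is asserted separately instead of
# being encoded in a second CPU-only alias.
FP8Granularity = Union[PerTensor, PerRow, PerGroup]

def _assert_granularity_supported(granularity: FP8Granularity, device_type: str) -> None:
    # Hypothetical helper: PerGroup is currently only handled by the CPU
    # (Float8OpaqueTensor) path.
    if device_type != "cpu":
        assert isinstance(granularity, (PerTensor, PerRow)), (
            f"{type(granularity).__name__} is not supported on {device_type} yet"
        )
```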
torchao/quantization/quant_api.py
Outdated
    block_size = get_block_size(x.shape, activation_granularity)
else:
    group_size = activation_granularity.group_size
    block_size = (*([1] * (len(x.shape) - 1)), group_size)
Why is this not included in `get_block_size`?
Updated. Thanks.
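An illustrative sketch of how the PerGroup case could fold into a `get_block_size`-style helper; this mirrors the arithmetic in the diff above, though torchao's actual `get_block_size` may differ:

```python
from torchao.quantization.granularity import PerGroup, PerRow, PerTensor

def get_block_size_sketch(shape: tuple[int, ...], granularity) -> tuple[int, ...]:
    if isinstance(granularity, PerTensor):
        # One scale for the whole tensor: the block covers every dimension.
        return tuple(shape)
    if isinstance(granularity, PerRow):
        # One scale per row: the block spans the entire last dimension.
        return (*([1] * (len(shape) - 1)), shape[-1])
    if isinstance(granularity, PerGroup):
        # One scale per group of `group_size` elements along the last dim.
        return (*([1] * (len(shape) - 1)), granularity.group_size)
    raise ValueError(f"Unsupported granularity: {granularity!r}")
```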
torchao/quantization/quant_api.py
Outdated
-    _check_hardware_support(granularity)
+    is_cpu = weight.device.type == "cpu"
+    if not is_cpu:
+        _check_hardware_support(granularity)
Can you move this to version 1? Then version 2 can probably do this check in the tensor itself.
Sure. Thanks.
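A self-contained sketch of the version split being suggested; the config and classes below are demo stand-ins, not torchao's API:

```python
from dataclasses import dataclass, field

import torch

from torchao.quantization.granularity import PerGroup, PerRow, PerTensor

@dataclass
class DemoFloat8Config:
    granularity: object = field(default_factory=PerTensor)
    version: int = 2

class DemoFloat8OpaqueWeight:
    # Stand-in for Float8OpaqueTensor: validates in its own constructor,
    # which is what "do this check in the tensor itself" suggests.
    def __init__(self, weight: torch.Tensor, granularity) -> None:
        assert weight.device.type == "cpu", "opaque-tensor path is CPU-only"
        assert isinstance(granularity, (PerTensor, PerRow, PerGroup))
        self.weight, self.granularity = weight, granularity

def transform(weight: torch.Tensor, config: DemoFloat8Config):
    if config.version == 1:
        # version 1: validate eagerly in the workflow function (legacy).
        assert isinstance(config.granularity, (PerTensor, PerRow))
        return weight  # legacy quantization flow elided
    # version 2: the tensor subclass performs the checks itself.
    return DemoFloat8OpaqueWeight(weight, config.granularity)
```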
torchao/quantization/quant_api.py
Outdated
    if not is_cpu and not _fp8_mm_compat(weight):
        # TODO(future PR): this should really throw an exception instead of silently
        # not doing what the user asked
        return weight

-    if isinstance(weight_granularity, PerRow):
+    if not is_cpu and isinstance(weight_granularity, PerRow):
        assert weight.dtype == torch.bfloat16, (
            "PerRow quantization only works for bfloat16 precision input weight"
        )
Also, for these checks: I feel we can move them to the version 1 branch for now and deprecate them later; we can add the checks to the tensors for version 2.
Sure. Thanks.
Moving this to the version=1 branch causes CI failures, so I will keep them as is; maybe it can be improved later. Thanks.
The data type check is kept here, and the `_fp8_mm_compat` check is moved to version=1. Thanks.
@jerryzh168 Could you please review this PR again? Thanks.
-    ]
+    ],
+    supported_granularities: tuple[FP8Granularity] = (PerTensor, PerRow),
+    support_different_granularities: bool = False,
This is weird; I think we should have `normalize_granularity` only do normalization, not validation as well.
I feel the same actually. Where should we put the validation? Thanks.
This seems to be `_normalize_and_validate_granularities`. Can you define separate functions for both the float8 tensor and the float8 opaque tensor in the tensor files themselves, i.e. float8_tensor.py and float8_opaque_tensor.py? It will probably be clearer if you do this in a separate PR: first move the original `_normalize` function to float8_tensor.py and change all the call sites, and then in this PR you only need to add a new one for float8_opaque_tensor.py.
Sounds good. Will do. Thanks.
How about version=1? Call `_validate_granularity` explicitly? In that case, `_validate_granularity` cannot be bound to a specific tensor type, I guess. And `_normalize_granularity` (with checks) is called elsewhere too, e.g. ao/torchao/quantization/quant_api.py, line 1945 in 838dceb:

(act_granularity, weight_granularity) = _normalize_granularity(granularity)

How shall we do validation at these locations?
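One possible shape for that separation, sketched with illustrative names (normalization stays generic; each tensor type owns its validation):

```python
from typing import Optional, Union

from torchao.quantization.granularity import PerGroup, PerRow, PerTensor

FP8Granularity = Union[PerTensor, PerRow, PerGroup]

def _normalize_granularity(
    granularity: Optional[Union[FP8Granularity, tuple]],
) -> tuple[FP8Granularity, FP8Granularity]:
    # Pure normalization: expand None or a single value into an
    # (activation, weight) pair; no device or dtype checks here.
    if granularity is None:
        return (PerTensor(), PerTensor())
    if isinstance(granularity, tuple):
        return granularity
    return (granularity, granularity)

def _validate_for_float8_tensor(act, wgt) -> None:
    # Float8Tensor (GPU) path: no PerGroup, granularities must match.
    assert not isinstance(act, PerGroup) and not isinstance(wgt, PerGroup)
    assert type(act) is type(wgt), "different granularities are not supported"

def _validate_for_float8_opaque_tensor(act, wgt) -> None:
    # Float8OpaqueTensor (CPU) path: PerGroup is allowed, and activation
    # and weight granularities may differ.
    assert isinstance(act, (PerTensor, PerRow, PerGroup))
    assert isinstance(wgt, (PerTensor, PerRow, PerGroup))
```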
Summary
We split the original big PR #2505 into the following smaller ones:
Test plan