[tensor subclass] print type_string of tensor attributes #1592

Open
wants to merge 12 commits into base: subclass/check_tensor_attrs
Conversation

@crcrpar (Collaborator) commented Dec 29, 2024

What does this PR do?

As per the title: print the type strings of a tensor subclass's tensor attributes in traces, as in the example below.

```python
import thunder
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
from thunder.tests.test_tensor_subclass import ScaleTensorSubclass
import thunder.tests.test_tensor_subclass
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, scale):
  # x: "cpu f32[2, 2]"
  # scale: "cpu f32[]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/tests/test_tensor_subclass.py:68:                   x.size(),
  (_, _) = prims.shape(x)

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/proxies.py:1965:                   self.requires_grad,
  self = ScaleTensorSubclass(x, scale)  # self: "ScaleTensorSubclass[cpu f32[2, 2]] (cpu f32[2, 2], cpu f32[])"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/tests/test_tensor_subclass.py:68:                   x.size(),
  (_, _) = prims.shape(x)

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/proxies.py:1965:                   self.requires_grad,
  output = ScaleTensorSubclass(x, scale)  # output: "ScaleTensorSubclass[cpu f32[2, 2]] (_x: cpu f32[2, 2], _scale: cpu f32[])"
  return (output,)
```
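
For reference, here is a minimal sketch of what a wrapper subclass along the lines of `ScaleTensorSubclass` could look like. This is a hypothetical `ScaleTensorLike`: only the attribute names `_x` and `_scale` are taken from the trace above, everything else is an assumption and may differ from the real `thunder.tests.test_tensor_subclass.ScaleTensorSubclass`. The tensor attributes exposed via `__tensor_flatten__` are what the new type strings annotate.

```python
import torch


class ScaleTensorLike(torch.Tensor):
    """Hypothetical stand-in for ScaleTensorSubclass; not Thunder's implementation."""

    @staticmethod
    def __new__(cls, x: torch.Tensor, scale: torch.Tensor):
        # Wrapper subclass: the outer tensor mirrors x's metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls, x.size(), dtype=x.dtype, device=x.device, requires_grad=x.requires_grad
        )

    def __init__(self, x: torch.Tensor, scale: torch.Tensor):
        self._x = x
        self._scale = scale

    def __tensor_flatten__(self):
        # The inner tensor attributes; these are what the trace above now
        # annotates as "(_x: cpu f32[2, 2], _scale: cpu f32[])".
        return ["_x", "_scale"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        return ScaleTensorLike(inner_tensors["_x"], inner_tensors["_scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap subclass arguments, run the op on the inner tensors, rewrap the result.
        kwargs = kwargs or {}

        def unwrap(a):
            return a._x if isinstance(a, ScaleTensorLike) else a

        out = func(*[unwrap(a) for a in args], **{k: unwrap(v) for k, v in kwargs.items()})
        scale = next((a._scale for a in args if isinstance(a, ScaleTensorLike)), None)
        return ScaleTensorLike(out, scale) if isinstance(out, torch.Tensor) else out

    def __repr__(self):
        return f"ScaleTensorLike(_x={self._x}, _scale={self._scale})"
```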

@crcrpar crcrpar force-pushed the subclass_tensor-type-str branch from 0931c59 to 4a16355 Compare December 30, 2024 12:18
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Dec 30, 2024
@crcrpar crcrpar force-pushed the subclass/check_tensor_attrs branch from 1069c05 to edfd224 Compare December 30, 2024 12:21
@crcrpar (Collaborator, Author) commented Dec 31, 2024

Currently `thunder/tests/test_apex_fused_norms.py` is failing in my environment, and the cause seems to be the `dce` calls inside `jit_ext`'s `_general_jit_torch_autograd_function_apply_lookaside`.

8435406 would fix this.

@crcrpar crcrpar force-pushed the subclass_tensor-type-str branch from 4a16355 to de00edd Compare January 2, 2025 11:09
@github-actions github-actions bot removed the documentation Improvements or additions to documentation label Jan 2, 2025
@crcrpar crcrpar force-pushed the subclass/check_tensor_attrs branch from edfd224 to 8e6d110 Compare January 2, 2025 13:52
@crcrpar crcrpar force-pushed the subclass_tensor-type-str branch from 70bb2f2 to 07a96d6 Compare January 2, 2025 13:53
crcrpar and others added 12 commits January 29, 2025 15:40
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
to support `__torch_dispatch__`.
Since it extends behavior that is implemented at the C++ level,
we'd need to apply the transform to the split forward and backward
traces separately.

Signed-off-by: Masaki Kozuki <[email protected]>
to support `__torch_dispatch__`.
Since it extends behavior that is implemented at the C++ level,
we'd need to apply the transform to the split forward and backward
traces separately (see the usage sketch below).

Signed-off-by: Masaki Kozuki <[email protected]>
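
As a rough usage illustration of the point above, using the hypothetical `ScaleTensorLike` sketched under the PR description (not the real `ScaleTensorSubclass`, and not Thunder's transform):

```python
import torch

# Assumes the ScaleTensorLike sketch from the PR description above.
x = ScaleTensorLike(torch.randn(2, 2), torch.tensor(2.0))
y = torch.mul(x, 3.0)  # routed through ScaleTensorLike.__torch_dispatch__, below the C++ dispatcher
# Under autograd, the ops recorded for the backward pass would be intercepted the same
# way, which is presumably why the subclass-handling transform has to be applied to the
# forward and backward traces separately once they have been split.
```
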
- Add `scaled_mm`
- Change how the lookaside of `torch.autograd.Function.apply` applies `dce`, taking the failure of the apex fused RMS norm into consideration.

```python
@torch.no_grad()
@no_autocast
def FusedRMSNormAffineMixedDtypesFunction(t_0, t_1, tup11, f12, b13):
  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128:                 weight_ = weight.contiguous()
  # t_0: "cuda:0 f32[4, 5, 3, 2]"
  # t_1: "cuda:0 f32[3, 2]"

  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:127:                 input_ = input.contiguous()
  t5 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0)  # t5: "cuda:0 f32[4, 5, 3, 2]"
    # t5 = prims.stride_order(t_0, (3, 2, 1, 0))  # t5: "cuda:0 f32[4, 5, 3, 2]"

  # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128:                 weight_ = weight.contiguous()
  t6 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f32[3, 2]"
    # t6 = prims.stride_order(t_1, (1, 0))  # t6: "cuda:0 f32[3, 2]"
  (t10, t9) = apex_fused_rms_norm_forward_affine_mixed_dtypes(t5, (3, 2), t6, 1e-05)
  return t10
```
For this trace, `thunder.core.transforms.dce` replaces `t9` with `_`, so the augmented forward trace would lose access to it. By reusing the augmented forward trace in the basic forward trace, `dce` no longer does so (see the toy sketch below).

Signed-off-by: Masaki Kozuki <[email protected]>
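
To make the failure mode concrete, here is a toy illustration of the idea. It is not Thunder's `dce`; the op and proxy names mirror the trace above, but the liveness logic is entirely hypothetical.

```python
# One bound symbol from the forward trace above: it produces (t10, t9),
# but only t10 is returned by the forward.
forward_ops = [
    ("apex_fused_rms_norm_forward_affine_mixed_dtypes", ("t5", "t6"), ("t10", "t9")),
]
forward_outputs = {"t10"}
saved_for_backward = {"t9"}  # the augmented forward / backward still needs t9


def toy_dce(ops, live):
    # Rename outputs outside the live set to "_", mimicking how an unused proxy
    # gets replaced and becomes unreachable afterwards.
    return [(name, args, tuple(o if o in live else "_" for o in outs)) for name, args, outs in ops]


print(toy_dce(forward_ops, forward_outputs))
# t9 -> "_": the augmented forward trace can no longer refer to it.
print(toy_dce(forward_ops, forward_outputs | saved_for_backward))
# Keeping what the backward needs live (i.e. reusing the augmented forward trace
# as the basic forward trace) preserves t9.
```
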
Also use `pytorch_executor` in the `transform_for_execution` of `prologue_trace`, as it could contain the tensor-subclass-flattening prim whose definition is only available in the PyTorch executor (see the illustration below).

Signed-off-by: Masaki Kozuki <[email protected]>
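
For context, here is a small illustration of what flattening a tensor subclass exposes, again using the hypothetical `ScaleTensorLike` from the PR description rather than the actual flattening prim. This per-attribute device/dtype/shape information is what both a prologue check and the new type strings summarize.

```python
import torch

# Assumes the ScaleTensorLike sketch from the PR description above.
s = ScaleTensorLike(torch.randn(2, 2), torch.tensor(2.0))
attr_names, ctx = s.__tensor_flatten__()  # (["_x", "_scale"], None)
inner = {name: getattr(s, name) for name in attr_names}
for name, t in inner.items():
    # e.g. "_x cpu torch.float32 (2, 2)" and "_scale cpu torch.float32 ()"
    print(name, t.device, t.dtype, tuple(t.shape))
```
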
@crcrpar crcrpar force-pushed the subclass_tensor-type-str branch from 07a96d6 to 9f0a8ef Compare February 6, 2025 13:30
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Feb 6, 2025