[Transforms] Enable shared memory and introduce permutations #284
base: transform_apply_support
@@ -33,11 +33,10 @@ class Transforms(RegistryMixin):
     def __init__(
         self,
         transform: torch.Tensor,
-        learnable: Optional[bool] = True,
-        device: Optional[Union[str, torch.device]] = "cuda",
-        dtype: Optional[torch.dtype] = torch.bfloat16,
+        transform_name: str,
+        permutation: Optional[torch.Tensor] = None,
+        permutation_name: Optional[str] = None,
     ):
-        self.learnable = learnable
         """
         Base class for setting up transforms. The registry creates transforms
         as parameters which can be attached to modules.
@@ -62,33 +61,47 @@ def __init__(
         :param transform: transform (e.g. torch.Tensor, scalar) to be applied
         """
-        if self.learnable:
-            self.transform = torch.nn.Parameter(transform.to(dtype).to(device))
-        else:
-            self.transform = torch.nn.Buffer(transform.to(dtype).to(device))
+        self.transform = torch.nn.Parameter(transform, requires_grad=False)
+        self.transform_name = transform_name
+        self.permutation = (
[Review comment] Since permutations are specific to […]
+            torch.nn.Parameter(permutation, requires_grad=False)
+            if permutation is not None
+            else None
+        )
+        self.permutation_name = permutation_name
+
+    def update_device(self, module: torch.nn.Module):
+        # Helper function required for deserialization
+        module_device = next(module.parameters()).device
+        self.transform.data = self.transform.data.to(module_device)
+        if self.permutation is not None:
+            self.permutation.data = self.permutation.data.to(module_device)

     # register to class for easy offloading, serialization, deserialization
-    def register_to_module(self, name: str, module: torch.nn.Module):
-        if self.learnable:
-            register_offload_parameter(module, name, self.transform)
-        else:
-            # TODO: have to verify serialization/offloading
-            module.register_buffer(name, self.transform)
+    # TODO: Manage lifecycle of permutation and transform buffers
+    def register_to_module(self, module: torch.nn.Module):
+        register_offload_parameter(module, self.transform_name, self.transform)
+        if self.permutation is not None:
+            register_offload_parameter(module, self.permutation_name, self.permutation)

     def update_transform(
         self,
         data: torch.Tensor,
+        permutation_data: Optional[torch.Tensor] = None,
         module: Optional[torch.nn.Module] = None,
-        name: Optional[str] = None,
     ):
         if module is None:
             self.transform.data.copy_(data)
+            if self.permutation is not None and permutation_data is not None:
+                self.permutation.data.copy_(permutation_data)
[Review comment] do we need […]
[Reply] this is in the case where the parameter isn't registered to a module for whatever reason; update_parameter_data handles module_params. I'm not sure this case is totally necessary, but if we decide to keep it we would have to add offloading/onloading around it.
         else:
             # If updating the module parameter data, assumes this is also the transform
-            # data
-            if name is None:
-                raise ValueError("Name and module are required to update parma data")
-            update_parameter_data(module, data, name)
+            # is already registered/shared data
+            update_parameter_data(module, data, self.transform_name)
+
+            if self.permutation is not None and permutation_data is not None:
+                update_parameter_data(module, permutation_data, self.permutation_name)

     def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
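Taken together, the reworked base class now always stores the transform, and the optional permutation, as frozen torch.nn.Parameter objects and registers those same objects onto each target module, which is what enables the shared memory mentioned in the PR title. Below is a rough usage sketch, not the library's API: it substitutes plain register_parameter for the repo's register_offload_parameter helper, and the class and tensor names are illustrative only.

import torch

class SketchTransform:
    """Illustrative stand-in for the reworked Transforms base class (sketch only)."""

    def __init__(self, transform, transform_name, permutation=None, permutation_name=None):
        # requires_grad=False mirrors the non-learnable parameters introduced in this PR
        self.transform = torch.nn.Parameter(transform, requires_grad=False)
        self.transform_name = transform_name
        self.permutation = (
            torch.nn.Parameter(permutation, requires_grad=False)
            if permutation is not None
            else None
        )
        self.permutation_name = permutation_name

    def register_to_module(self, module):
        # Registering the *same* Parameter object on every module shares its storage
        module.register_parameter(self.transform_name, self.transform)
        if self.permutation is not None:
            module.register_parameter(self.permutation_name, self.permutation)

linear_a = torch.nn.Linear(8, 8)
linear_b = torch.nn.Linear(8, 8)
shared = SketchTransform(
    transform=torch.eye(8),
    transform_name="transform_hadamard",
    permutation=torch.randperm(8).to(torch.float32),
    permutation_name="transform_perm",
)
shared.register_to_module(linear_a)
shared.register_to_module(linear_b)
# Both modules now point at the same underlying tensor storage
assert linear_a.transform_hadamard.data_ptr() == linear_b.transform_hadamard.data_ptr()

The remaining hunks below touch the Hadamard utilities used to generate these transform matrices.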
@@ -18,7 +18,11 @@
 import torch


-__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
+__all__ = [
+    "random_hadamard_matrix",
+    "deterministic_hadamard_matrix",
+    "SingletonHadamardRegistry",
+]

 # adapted from:
 # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
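SingletonHadamardRegistry is newly exported in __all__, but its definition is not part of the visible hunks. Purely as a guess at intent, here is a minimal sketch of a singleton registry that caches one Hadamard matrix per size so the same tensor object can be reused by every transform that needs it; all names and behaviour below are assumptions, not the actual implementation.

import torch

class HadamardCacheSketch:
    # Sketch only: one shared instance caching a Hadamard matrix per size,
    # so repeated lookups return the same tensor object (shared memory).
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._cache = {}
        return cls._instance

    def get(self, size, builder):
        # `builder` would be something like random_hadamard_matrix from this module
        if size not in self._cache:
            self._cache[size] = builder(size)
        return self._cache[size]

registry = HadamardCacheSketch()
h1 = registry.get(4, lambda n: torch.ones(n, n))  # stand-in builder
h2 = registry.get(4, lambda n: torch.ones(n, n))
assert h1 is h2  # same object, so no duplicate matrices in memory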
@@ -59,6 +63,7 @@ def deterministic_hadamard_matrix(size: int):
 # https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


+# ToDo: should no longer be random, call something else --> different generation type than scipy?
 def random_hadamard_matrix(size: int) -> torch.Tensor:
     """
     Produces a randomly generated Hadamard matrix.
@@ -73,7 +78,8 @@ def random_hadamard_matrix(size: int) -> torch.Tensor:
     # the matrix generated to be reproducible

     # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
+    # Q = torch.randint(low=1, high=2, size=(size,)).to(torch.float64)
+    Q = torch.ones(size).to(torch.float64)
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     return _matmul_hadU(Q)
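For context on what the changed lines compute: replacing the random 0/1 draw with torch.ones means the ±1 diagonal collapses to all +1, i.e. Q becomes the identity, so the sign randomization is effectively disabled and the generated matrix is the same on every call. That matches the ToDo above about the function no longer being random. A standalone illustration (not the repo's code):

import torch

size = 4

# Previous behaviour: a random +/-1 diagonal gives a different sign pattern per call
q_old = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) * 2 - 1

# Behaviour after this change: an all-ones vector maps to all +1, so diag(q) == I
q_new = torch.ones(size).to(torch.float64) * 2 - 1

print(torch.diag(q_old))  # random +/-1 entries on the diagonal
print(torch.diag(q_new))  # identity matrix, identical across calls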
@@ -129,7 +135,7 @@ def _matmul_hadU(X, transpose=False):
     input = hadK.view(1, K, K).to(input) @ input

     # normalize
[Review comment] this comment might need to go too if we are not normalizing?
-    return input.view(X.shape) / torch.tensor(n).sqrt()
+    return input.view(X.shape)


 def _is_pow2(n):
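Dropping the division by sqrt(n) means _matmul_hadU now returns an unnormalized Hadamard matrix: its rows are orthogonal but have squared norm n rather than 1, so if an orthonormal transform is still wanted, that scale has to be applied somewhere downstream. A standalone check of this property; the Sylvester construction here is only for illustration, not the repo's generator.

import torch

def sylvester_hadamard(n):
    # Classic Sylvester construction for powers of two; entries are +/-1 (unnormalized)
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H

n = 8
H = sylvester_hadamard(n)
# Without the 1/sqrt(n) normalization, H is orthogonal only up to a factor of n
assert torch.allclose(H @ H.T, n * torch.eye(n, dtype=torch.float64))
# An orthonormal version still requires the scale to be applied by the caller
H_norm = H / torch.tensor(float(n)).sqrt()
assert torch.allclose(H_norm @ H_norm.T, torch.eye(n, dtype=torch.float64))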
[Review comment] it seems like we have to do some sort of name matching, but I'm wondering if some name collision down the road is going to cause this to run when we don't want it? It might help if we came up with a more unique name or something to prevent false positives.
[Reply] we have this problem with any parameter we introduce (e.g. weight_scale, weight_g_idx, etc.), but yeah, we can work on making them more unique.