[Transforms] Enable shared memory and introduce permutations #284

Open · wants to merge 5 commits into base: transform_apply_support

Conversation

@dsikka (Collaborator) commented Mar 26, 2025

Summary

1. Makes a series of updates to the registry to support caching of the transform matrices

  • With these changes, instead of repeatedly generating Hadamards of a particular size, we cache the matrix using its size as the key. This significantly speeds up transform set-up by leveraging the fact that Hadamards are deterministic (see the sketch after this list)
  • For randomized Hadamards, splits up the math so that the underlying Hadamard can be cached and the randomness is introduced as a separate permutation matrix [will need to add a few tests to make sure my math is correct]
  • For the matrix-multiply transform, caches based on a user-provided name

2. Also uses shared memory, so that layers with identical transforms use the same underlying transform data.

  • Significantly reduces the memory required by transforms
  • NOTE: For training, we may need to revisit this depending on how transform updates are expected to be made during training, since the data is now shared

3. Move update/register functionality to be done inside the registry; introduce a permutation parameter

  • The update and register steps are the same as before but now happen inside the registry for slightly more clarity
  • Introduces the permutation parameter. It is only used for the random Hadamard for now, but the remaining transforms will follow as well

4. Rename "global" to "shared"

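Since items 1 and 2 both hinge on caching and shared storage, here is a minimal sketch of the idea; all names below are illustrative rather than the PR's actual API. The deterministic Hadamard is cached by size, the randomness is pushed into a separate permutation matrix so the randomized transform can be rebuilt as H @ P from the cache, and identical transforms share a single Parameter's storage.

import torch

def deterministic_hadamard(size: int) -> torch.Tensor:
    # Illustrative Sylvester construction; the repo's own helper may differ.
    assert (size & (size - 1)) == 0, "sketch assumes a power-of-two size"
    h = torch.ones(1, 1)
    while h.shape[0] < size:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

# 1. Cache deterministic Hadamards by size; a matrix of a given size never changes.
_cache = {}

def cached_hadamard(size: int) -> torch.Tensor:
    if size not in _cache:
        _cache[size] = deterministic_hadamard(size)
    return _cache[size]

# Randomness is factored out of the cached matrix: a column permutation preserves the
# +/-1 entries and orthogonality, so the randomized transform is H @ P.
size = 8
H = cached_hadamard(size)
P = torch.eye(size)[torch.randperm(size)]
randomized = H @ P

# 2. Shared memory: registering the same Parameter on two layers means both reference
# the same storage, so identical transforms are only materialized once.
layer_a, layer_b = torch.nn.Linear(size, size), torch.nn.Linear(size, size)
shared = torch.nn.Parameter(randomized, requires_grad=False)
layer_a.register_parameter("transform", shared)
layer_b.register_parameter("transform", shared)
assert layer_a.transform.data_ptr() == layer_b.transform.data_ptr()
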
@dsikka marked this pull request as ready for review on Mar 26, 2025 at 23:10
@brian-dellabetta (Contributor) left a comment:

i think this makes sense 👍

@@ -78,7 +79,7 @@ def load_transforms(model: Module, model_name_or_path: str):

state_dict = {}
for weight_name, safe_path in weight_mappings.items():
if "transform" in weight_name:
if "transform" in weight_name or "_perm_" in weight_name:
Contributor:

it seems like we have to do some sort of name matching, but i'm wondering if some name collision down the road is going to cause this to run when we don't want it? maybe we could come up with a more unique name or something to prevent false positives

Collaborator (Author):

we have this problem with any parameter we introduce (e.g. weight_scale, weight_g_idx, etc.), but yeah, we can work on making them more unique
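
One possible direction, sketched purely as an illustration (the suffix names here are hypothetical, not what the PR uses): match against an explicit allow-list of suffixes instead of substring containment, which narrows the window for false positives.

# Hypothetical suffix allow-list; the real parameter names would need to be agreed on.
TRANSFORM_WEIGHT_SUFFIXES = ("_transform", "_perm")

def is_transform_weight(weight_name: str) -> bool:
    # endswith with a tuple checks each suffix; avoids matching arbitrary substrings.
    return weight_name.endswith(TRANSFORM_WEIGHT_SUFFIXES)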

):
    if module is None:
        self.transform.data.copy_(data)
        if self.permutation is not None and permutation_data is not None:
            self.permutation.data.copy_(permutation_data)
Contributor:

do we need update_parameter_data here too in case of offloading?

Collaborator (Author):

this is for the case where the parameter isn't registered to a module for whatever reason; update_parameter_data handles module params. I'm not sure this case is totally necessary, but yeah, we would have to add offloading/onloading around it if we decide to keep it
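
For reference, a rough sketch of the two paths being discussed, assuming update_parameter_data(module, data, name) roughly as it exists in compressed-tensors; the method and attribute names are otherwise assumptions, not the PR's implementation.

from typing import Optional

import torch
from compressed_tensors.utils import update_parameter_data  # assumed import path

def update_transform(self, data: torch.Tensor, module: Optional[torch.nn.Module] = None):
    if module is not None:
        # Parameter is registered on a module: go through update_parameter_data
        # so any offloading hooks are respected.
        update_parameter_data(module, data, self.transform_name)
    else:
        # Fallback for a parameter that isn't registered to a module; offloading
        # handling would have to be added around this if the path is kept.
        self.transform.data.copy_(data)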

@@ -129,7 +135,7 @@ def _matmul_hadU(X, transpose=False):
input = hadK.view(1, K, K).to(input) @ input

# normalize
Contributor:

this comment might need to go too if we are not normalizing?

__all__ = ["apply_matrix_transform", "SingletonMatrixRegistry"]


class SingletonMatrixRegistry:
Contributor:

so that all matrices live in a single global key-value store, right?

Collaborator (Author):

yeah, we can expand this, but it seems like there will be a lot of repetition across decoder layers, for example.
I think if this grows too big in scope, we may have to consider other data stores to handle it
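
For context, a rough sketch of the singleton pattern under discussion; the PR's actual SingletonMatrixRegistry may differ in its interface.

import torch

class SingletonMatrixRegistry:
    # One process-wide instance backed by a single dict, so a matrix keyed e.g. by
    # Hadamard size is created once and then shared across all decoder layers.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._store = {}
        return cls._instance

    def contains(self, key) -> bool:
        return key in self._store

    def get(self, key) -> torch.Tensor:
        return self._store[key]

    def set(self, key, value: torch.Tensor) -> None:
        self._store[key] = value

# Any two instantiations see the same underlying store.
assert SingletonMatrixRegistry() is SingletonMatrixRegistry()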

__all__ = ["apply_matrix_transform", "SingletonMatrixRegistry"]


class SingletonMatrixRegistry:
Contributor:

Is a new singleton class necessary? Since each matrix Transform requires its own registry, shouldn't this be implemented on the class itself?

class Hadamard(Transforms):
    registry: Dict[int, torch.Tensor] = {}

    def __new__(cls, size, empty, transform_name, *args, **kwargs):
        if empty:
            matrix = ...
        else:
            # setdefault so the generated matrix is actually stored in the cache
            matrix = cls.registry.setdefault(
                size, torch.Tensor(deterministic_hadamard_matrix(size=size))
            )
        return super().__new__(transform=matrix, transform_name=transform_name)

self.transform = torch.nn.Buffer(transform.to(dtype).to(device))
self.transform = torch.nn.Parameter(transform, requires_grad=False)
self.transform_name = transform_name
self.permutation = (
Contributor:

> For randomized hadamards, splits up the math so that the underlying hadamard can be cached and the randomness is introduced as a separate permutation matrix

Since permutations are specific to RandomHadamards, shouldn't we be implementing this logic on the RandomHadamard class, not the general Transforms class?
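
A minimal sketch of what the suggested split could look like, keeping the permutation on the subclass; the base-class shape is simplified here and the names are assumptions, not the repo's actual classes.

import torch

class Transforms(torch.nn.Module):
    # Simplified stand-in for the base class: holds only the shared transform matrix.
    def __init__(self, transform: torch.Tensor, transform_name: str):
        super().__init__()
        self.transform = torch.nn.Parameter(transform, requires_grad=False)
        self.transform_name = transform_name

class RandomHadamard(Transforms):
    # The permutation lives on the subclass that needs it, not on Transforms.
    def __init__(self, hadamard: torch.Tensor, transform_name: str):
        super().__init__(hadamard, transform_name)
        size = hadamard.shape[0]
        perm = torch.eye(size, dtype=hadamard.dtype)[torch.randperm(size)]
        self.permutation = torch.nn.Parameter(perm, requires_grad=False)

    def materialize(self) -> torch.Tensor:
        # The applied transform composes the (cached, shared) Hadamard with this
        # instance's own permutation.
        return self.transform @ self.permutation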
