[Transforms] Enable shared memory and introduce permutations #284

Open · wants to merge 5 commits into base: transform_apply_support

24 changes: 17 additions & 7 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -62,6 +62,7 @@
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
from compressed_tensors.utils import has_offloaded_params
from compressed_tensors.utils.safetensors_load import (
get_quantization_state_dict,
get_weight_mappings,
@@ -78,7 +79,7 @@ def load_transforms(model: Module, model_name_or_path: str):

state_dict = {}
for weight_name, safe_path in weight_mappings.items():
if "transform" in weight_name:
if "transform" in weight_name or "_perm_" in weight_name:
Contributor

It seems like we have to do some sort of name matching, but I'm wondering if a name collision down the road is going to cause this to run when we don't want it. Could we come up with a more unique name or pattern to prevent false positives?

Collaborator Author

We have this problem with any parameter we introduce (e.g. weight_scale, weight_g_idx, etc.), but yeah, we can work on making the names more unique.
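
As a hedged illustration of the stricter matching discussed in this thread (the helper and pattern below are hypothetical, not part of this PR), the check could anchor on the generated "<target>_transform_<idx>" / "<target>_perm_<idx>" suffixes instead of a plain substring test:

import re

# Hypothetical stricter filter: only match the suffixes produced by
# process_transforms_config, rather than any name containing "transform".
_TRANSFORM_PARAM_SUFFIX = re.compile(r"_(transform|perm)_\d+$")

def is_transform_param(weight_name: str) -> bool:
    return _TRANSFORM_PARAM_SUFFIX.search(weight_name) is not None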

with safe_open(safe_path, framework="pt", device="cpu") as f:
state_dict[weight_name] = f.get_tensor(weight_name)

@@ -88,11 +89,14 @@
if transform_data:
for transform_name, transform_values in transform_data.data.items():
full_name = f"{name}.{transform_name}"
transform_data = state_dict.get(full_name, None)
full_per_name = full_name.replace("transform", "perm")
dict_data = state_dict.get(full_name, None)
permutation_data = state_dict.get(full_per_name, None)
transform = transform_values.get("transform")
transform.register_to_module(name=transform_name, module=submodule)
transform.update_device(module=submodule)
transform.register_to_module(module=submodule)
transform.update_transform(
module=submodule, data=transform_data, name=transform_name
module=submodule, data=dict_data, permutation_data=permutation_data
)
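
For illustration only (the module path here is made up), the permutation key is derived from the transform key by substring replacement:

full_name = "model.layers.0.self_attn.q_proj.weight_transform_0"
full_per_name = full_name.replace("transform", "perm")
# -> "model.layers.0.self_attn.q_proj.weight_perm_0"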


@@ -180,6 +184,7 @@ def process_transforms_config(
# only support weight parameters for now, assume one value in
# module targets
transform_name = f"{module_targets[0]}_transform_{idx}"
permutation_name = f"{module_targets[0]}_perm_{idx}"

# create an empty tensor OR create a new transform
dtype = getattr(submodule, module_targets[0]).dtype
@@ -191,17 +196,21 @@
transform_type,
dtype=dtype,
empty=True,
transform_name=transform_name,
permutation_name=permutation_name,
**transform_creation_args,
)
else:
transform = Transforms.load_from_registry(
transform_type,
dtype=dtype,
transform_name=transform_name,
permutation_name=permutation_name,
device=next(submodule.parameters()).device,
**transform_creation_args,
)
transform.register_to_module(
name=transform_name, module=submodule
)

transform.register_to_module(module=submodule)

# add relevant transform data to the submodule as well
data = {
@@ -217,6 +226,7 @@
else:
transform_data = TransformData(data=OrderedDict(data))
submodule.transform_data = transform_data
# ~10358 MiB for now; roughly 1/3 of the memory when caching/sharing parameter data
return model


51 changes: 32 additions & 19 deletions src/compressed_tensors/transforms/base.py
@@ -33,11 +33,10 @@ class Transforms(RegistryMixin):
def __init__(
self,
transform: torch.Tensor,
learnable: Optional[bool] = True,
device: Optional[Union[str, torch.device]] = "cuda",
dtype: Optional[torch.dtype] = torch.bfloat16,
transform_name: str,
permutation: Optional[torch.Tensor] = None,
permutation_name: Optional[str] = None,
):
self.learnable = learnable
"""
Base class for setting up transforms. The registry creates transforms
as parameters which can be attached to modules.
@@ -62,33 +61,47 @@ def __init__(

:param transform: transform (e.g. torch.Tensor, scalar) to be applied
"""
if self.learnable:
self.transform = torch.nn.Parameter(transform.to(dtype).to(device))
else:
self.transform = torch.nn.Buffer(transform.to(dtype).to(device))
self.transform = torch.nn.Parameter(transform, requires_grad=False)
self.transform_name = transform_name
self.permutation = (
Contributor

For randomized Hadamards, this splits up the math so that the underlying Hadamard can be cached and the randomness is introduced as a separate permutation matrix.

Since permutations are specific to RandomHadamards, shouldn't we be implementing this logic on the RandomHadamard class, not the general Transforms class?

torch.nn.Parameter(permutation, requires_grad=False)
if permutation is not None
else None
)
self.permutation_name = permutation_name

def update_device(self, module: torch.nn.Module):
# Helper function required for deserialization
module_device = next(module.parameters()).device
self.transform.data = self.transform.data.to(module_device)
if self.permutation is not None:
self.permutation.data = self.permutation.data.to(module_device)

# register to class for easy offloading, serialization, deserialization
def register_to_module(self, name: str, module: torch.nn.Module):
if self.learnable:
register_offload_parameter(module, name, self.transform)
else:
# TODO: have to verify serialization/offloading
module.register_buffer(name, self.transform)
# TODO: Manage lifecycle of permutation and transform buffers
def register_to_module(self, module: torch.nn.Module):
register_offload_parameter(module, self.transform_name, self.transform)
if self.permutation is not None:
register_offload_parameter(module, self.permutation_name, self.permutation)

def update_transform(
self,
data: torch.Tensor,
permutation_data: Optional[torch.Tensor] = None,
module: Optional[torch.nn.Module] = None,
name: Optional[str] = None,
):
if module is None:
self.transform.data.copy_(data)
if self.permutation is not None and permutation_data is not None:
self.permutation.data.copy_(permutation_data)
Contributor

Do we need update_parameter_data here too in case of offloading?

Collaborator Author

This covers the case where the parameter isn't registered to a module for whatever reason; update_parameter_data handles module params. I'm not sure this case is totally necessary, but yeah, we would have to add offloading/onloading around it if we decide to keep it.


else:
# If updating the module parameter data, assumes this is also the transform
# data
if name is None:
raise ValueError("Name and module are required to update parma data")
update_parameter_data(module, data, name)
# is already registered/shared data
update_parameter_data(module, data, self.transform_name)

if self.permutation is not None and permutation_data is not None:
update_parameter_data(module, permutation_data, self.permutation_name)

def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
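
A minimal sketch of the factorization discussed in the thread above, assuming the randomized transform is split into a shared deterministic Hadamard and a small per-module random component (the function and variable names are illustrative, not this library's API):

import torch

def compose_randomized_transform(hadamard: torch.Tensor,
                                 random_factor: torch.Tensor) -> torch.Tensor:
    # The Hadamard matrix can be cached once per size and shared across modules;
    # only the cheap-to-store random factor differs per module.
    return hadamard @ random_factor

hadamard = torch.tensor([[1.,  1.,  1.,  1.],
                         [1., -1.,  1., -1.],
                         [1.,  1., -1., -1.],
                         [1., -1., -1.,  1.]])
signs = torch.diag(torch.tensor([1., -1., -1., 1.]))  # a random +/-1 diagonal in practice
randomized = compose_randomized_transform(hadamard, signs)
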
30 changes: 23 additions & 7 deletions src/compressed_tensors/transforms/hadamard.py
@@ -17,16 +17,20 @@
import torch
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix
from compressed_tensors.transforms.utils import apply_matrix_transform
from compressed_tensors.transforms.utils import (
SingletonMatrixRegistry,
apply_matrix_transform,
)


@Transforms.register("hadamard")
class Hadamard(Transforms):
def __init__(
self,
size: int,
transform_name: str,
empty: Optional[bool] = False,
device: Optional[Union[str, torch.device]] = "cuda",
device: Optional[Union[str, torch.device]] = "cpu",
dtype: Optional[torch.dtype] = torch.bfloat16,
*args,
**kwargs,
@@ -46,13 +50,25 @@ def __init__(
:param dtype: type to cast the rotation matrix to

"""
if not empty:
# TODO: this is deterministic; we should just serialize the size
transform = torch.Tensor(deterministic_hadamard_matrix(size=size))
self.matrix_registry = SingletonMatrixRegistry()
self.size = size

if empty:
transform = torch.empty((size, size)).to(dtype)
else:
transform = torch.empty((size, size))
transform = self.fetch().to(dtype).to(device)

super().__init__(transform=transform, transform_name=transform_name)

if not self.matrix_registry.contains(size):
self.matrix_registry.set_matrix(size, self.transform)

super().__init__(transform=transform, dtype=dtype, device=device)
def fetch(self):
# TODO: this is deterministic; we should just serialize the size
transform = self.matrix_registry.get_matrix(self.size)
if transform is None:
transform = torch.Tensor(deterministic_hadamard_matrix(size=self.size))
return transform

def apply(
self,
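
For context, a minimal sketch of a process-wide matrix cache compatible with the contains / get_matrix / set_matrix calls used above; the real SingletonMatrixRegistry in compressed_tensors.transforms.utils may differ:

import torch

class SingletonMatrixRegistry:
    # Single shared instance so every transform of a given size reuses the same
    # underlying tensor instead of allocating a new copy per module.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._matrices = {}
        return cls._instance

    def contains(self, key) -> bool:
        return key in self._matrices

    def get_matrix(self, key):
        return self._matrices.get(key)

    def set_matrix(self, key, matrix: torch.Tensor) -> None:
        self._matrices[key] = matrix
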
12 changes: 9 additions & 3 deletions src/compressed_tensors/transforms/hadamard_utils.py
@@ -18,7 +18,11 @@
import torch


__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
__all__ = [
"random_hadamard_matrix",
"deterministic_hadamard_matrix",
"SingletonHadamardRegistry",
]

# adapted from:
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
@@ -59,6 +63,7 @@ def deterministic_hadamard_matrix(size: int):
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


# TODO: this should no longer be random; call it something else --> different generation type than scipy?
def random_hadamard_matrix(size: int) -> torch.Tensor:
"""
Produces a randomly generated Hadamard matrix.
@@ -73,7 +78,8 @@
# the matrix generated to be reproducible

# Benefits: support other shapes / non powers of 2, support randomization
Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
# Q = torch.randint(low=1, high=2, size=(size,)).to(torch.float64)
Q = torch.ones(size).to(torch.float64)
Q = Q * 2 - 1
Q = torch.diag(Q)
return _matmul_hadU(Q)
@@ -129,7 +135,7 @@ def _matmul_hadU(X, transpose=False):
input = hadK.view(1, K, K).to(input) @ input

# normalize
Contributor

This comment might need to go too if we are not normalizing?

return input.view(X.shape) / torch.tensor(n).sqrt()
return input.view(X.shape)


def _is_pow2(n):
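
A quick numeric check (standalone, not part of the diff) of what dropping the normalization means: an unnormalized Hadamard matrix H of size n satisfies H @ H.T == n * I, so the 1/sqrt(n) scaling has to be applied elsewhere if an orthonormal rotation is required.

import torch

H = torch.tensor([[1.,  1.],
                  [1., -1.]])  # smallest Hadamard matrix
n = H.shape[0]
assert torch.equal(H @ H.T, n * torch.eye(n))      # unnormalized: H @ H.T == n * I
Hn = H / n ** 0.5
assert torch.allclose(Hn @ Hn.T, torch.eye(n))     # normalized: orthonormal
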
53 changes: 51 additions & 2 deletions src/compressed_tensors/transforms/matrix_multiply.py
@@ -12,14 +12,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.utils import apply_matrix_transform
from compressed_tensors.transforms.utils import (
SingletonMatrixRegistry,
apply_matrix_transform,
)


# TODO: fix loading
@Transforms.register("matrix-mul")
class MatrixMultiply(Transforms):
def __init__(
self,
name: str,
transform_name: str,
transform_data: Optional[torch.Tensor] = None,
size: Optional[int] = None,
empty: Optional[bool] = False,
device: Optional[Union[str, torch.device]] = "cpu",
dtype: Optional[torch.dtype] = torch.bfloat16,
*args,
**kwargs,
):

if empty and size is None:
raise ValueError(
"size is required when setting up parameters for deserialization "
)

if not empty and transform_data is None:
raise ValueError(
"transform_data is required when initializing matrix-multiply transforms"
)

# name required to either pull a cached matrix or cache a matrix itself
# will assume each name corresponds to a unique matrix
self.name = name
self.matrix_registry = SingletonMatrixRegistry()

# Can we get rid of the size for deserialization?
if empty:
transform = torch.empty((size, size)).to(dtype)
else:
transform = self.fetch(transform_data).to(dtype).to(device)

super().__init__(transform=transform, transform_name=transform_name)

if not self.matrix_registry.contains(self.name):
self.matrix_registry.set_matrix(self.name, self.transform)

def fetch(self, transform_data: torch.Tensor):
transform = self.matrix_registry.get_matrix(self.name)
if transform is None:
return transform_data
return transform

def apply(
self,
input_tensor: torch.Tensor,
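
A hedged usage sketch of the constructor above (the matrix and parameter names are illustrative; the sharing behavior is as assumed from the registry logic shown): two instances created with the same name are expected to reference one cached matrix rather than each holding its own copy.

import torch
from compressed_tensors.transforms.matrix_multiply import MatrixMultiply

rotation = torch.randn(64, 64)

# The first instance caches the matrix under "shared_rotation"; the second
# fetches the cached tensor instead of using the transform_data it was handed.
t1 = MatrixMultiply(
    name="shared_rotation",
    transform_name="weight_transform_0",
    transform_data=rotation,
)
t2 = MatrixMultiply(
    name="shared_rotation",
    transform_name="weight_transform_1",
    transform_data=torch.randn(64, 64),  # ignored: cache hit on "shared_rotation"
)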