diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index c2c5a704..255754df 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -62,6 +62,7 @@ ] from compressed_tensors.quantization.utils.helpers import is_module_quantized +from compressed_tensors.utils import has_offloaded_params from compressed_tensors.utils.safetensors_load import ( get_quantization_state_dict, get_weight_mappings, @@ -78,7 +79,7 @@ def load_transforms(model: Module, model_name_or_path: str): state_dict = {} for weight_name, safe_path in weight_mappings.items(): - if "transform" in weight_name: + if "transform" in weight_name or "_perm_" in weight_name: with safe_open(safe_path, framework="pt", device="cpu") as f: state_dict[weight_name] = f.get_tensor(weight_name) @@ -88,11 +89,14 @@ def load_transforms(model: Module, model_name_or_path: str): if transform_data: for transform_name, transform_values in transform_data.data.items(): full_name = f"{name}.{transform_name}" - transform_data = state_dict.get(full_name, None) + full_per_name = full_name.replace("transform", "perm") + dict_data = state_dict.get(full_name, None) + permutation_data = state_dict.get(full_per_name, None) transform = transform_values.get("transform") - transform.register_to_module(name=transform_name, module=submodule) + transform.update_device(module=submodule) + transform.register_to_module(module=submodule) transform.update_transform( - module=submodule, data=transform_data, name=transform_name + module=submodule, data=dict_data, permutation_data=permutation_data ) @@ -180,6 +184,7 @@ def process_transforms_config( # only support weight parameters for now, assume one value in # module targets transform_name = f"{module_targets[0]}_transform_{idx}" + permutation_name = f"{module_targets[0]}_perm_{idx}" # create an empty tensor OR create a new transform dtype = getattr(submodule, module_targets[0]).dtype @@ -191,17 +196,21 @@ def process_transforms_config( transform_type, dtype=dtype, empty=True, + transform_name=transform_name, + permutation_name=permutation_name, **transform_creation_args, ) else: transform = Transforms.load_from_registry( transform_type, dtype=dtype, + transform_name=transform_name, + permutation_name=permutation_name, + device=next(submodule.parameters()).device, **transform_creation_args, ) - transform.register_to_module( - name=transform_name, module=submodule - ) + + transform.register_to_module(module=submodule) # add relevant transform data to the submodule as well data = { @@ -217,6 +226,7 @@ def process_transforms_config( else: transform_data = TransformData(data=OrderedDict(data)) submodule.transform_data = transform_data + # 10358 for now mib; 1/3 of memory if caching/sharing parameter data return model diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py index 56adcb62..715c0625 100644 --- a/src/compressed_tensors/transforms/base.py +++ b/src/compressed_tensors/transforms/base.py @@ -33,11 +33,10 @@ class Transforms(RegistryMixin): def __init__( self, transform: torch.Tensor, - learnable: Optional[bool] = True, - device: Optional[Union[str, torch.device]] = "cuda", - dtype: Optional[torch.dtype] = torch.bfloat16, + transform_name: str, + permutation: Optional[torch.Tensor] = None, + permutation_name: Optional[str] = None, ): - self.learnable = learnable """ Base class for setting up transforms. 
The registry creates transforms as parameters which can be attached to modules. @@ -62,33 +61,47 @@ def __init__( :param transform: transform (e.g. torch.Tensor, scalar) to be applied """ - if self.learnable: - self.transform = torch.nn.Parameter(transform.to(dtype).to(device)) - else: - self.transform = torch.nn.Buffer(transform.to(dtype).to(device)) + self.transform = torch.nn.Parameter(transform, requires_grad=False) + self.transform_name = transform_name + self.permutation = ( + torch.nn.Parameter(permutation, requires_grad=False) + if permutation is not None + else None + ) + self.permutation_name = permutation_name + + def update_device(self, module: torch.nn.Module): + # Helper function required for deserialization + module_device = next(module.parameters()).device + self.transform.data = self.transform.data.to(module_device) + if self.permutation is not None: + self.permutation.data = self.permutation.data.to(module_device) # register to class for easy offloading, serialization, deserialization - def register_to_module(self, name: str, module: torch.nn.Module): - if self.learnable: - register_offload_parameter(module, name, self.transform) - else: - # TODO: have to verify serialization/offloading - module.register_buffer(name, self.transform) + # TODO: Manage lifecycle of permutation and transform buffers + def register_to_module(self, module: torch.nn.Module): + register_offload_parameter(module, self.transform_name, self.transform) + if self.permutation is not None: + register_offload_parameter(module, self.permutation_name, self.permutation) def update_transform( self, data: torch.Tensor, + permutation_data: Optional[torch.Tensor] = None, module: Optional[torch.nn.Module] = None, - name: Optional[str] = None, ): if module is None: self.transform.data.copy_(data) + if self.permutation is not None and permutation_data is not None: + self.permutation.data.copy_(permutation_data) + else: # If updating the module parameter data, assumes this is also the transform - # data - if name is None: - raise ValueError("Name and module are required to update parma data") - update_parameter_data(module, data, name) + # is already registered/shared data + update_parameter_data(module, data, self.transform_name) + + if self.permutation is not None and permutation_data is not None: + update_parameter_data(module, permutation_data, self.permutation_name) def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ diff --git a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py index ef0e27a4..7f280e55 100644 --- a/src/compressed_tensors/transforms/hadamard.py +++ b/src/compressed_tensors/transforms/hadamard.py @@ -17,7 +17,10 @@ import torch from compressed_tensors.transforms import Transforms from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix -from compressed_tensors.transforms.utils import apply_matrix_transform +from compressed_tensors.transforms.utils import ( + SingletonMatrixRegistry, + apply_matrix_transform, +) @Transforms.register("hadamard") @@ -25,8 +28,9 @@ class Hadamard(Transforms): def __init__( self, size: int, + transform_name: str, empty: Optional[bool] = False, - device: Optional[Union[str, torch.device]] = "cuda", + device: Optional[Union[str, torch.device]] = "cpu", dtype: Optional[torch.dtype] = torch.bfloat16, *args, **kwargs, @@ -46,13 +50,25 @@ def __init__( :param dtype: type to cast the rotation matrix to """ - if not empty: - # TODO: this is deterministic; we should just 
serialize the size - transform = torch.Tensor(deterministic_hadamard_matrix(size=size)) + self.matrix_registry = SingletonMatrixRegistry() + self.size = size + + if empty: + transform = torch.empty((size, size)).to(dtype) else: - transform = torch.empty((size, size)) + transform = self.fetch().to(dtype).to(device) + + super().__init__(transform=transform, transform_name=transform_name) + + if not self.matrix_registry.contains(size): + self.matrix_registry.set_matrix(size, self.transform) - super().__init__(transform=transform, dtype=dtype, device=device) + def fetch(self): + # TODO: this is deterministic; we should just serialize the size + transform = self.matrix_registry.get_matrix(self.size) + if transform is None: + transform = torch.Tensor(deterministic_hadamard_matrix(size=self.size)) + return transform def apply( self, diff --git a/src/compressed_tensors/transforms/hadamard_utils.py b/src/compressed_tensors/transforms/hadamard_utils.py index 2cbd74d8..de9364cd 100644 --- a/src/compressed_tensors/transforms/hadamard_utils.py +++ b/src/compressed_tensors/transforms/hadamard_utils.py @@ -18,7 +18,11 @@ import torch -__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"] +__all__ = [ + "random_hadamard_matrix", + "deterministic_hadamard_matrix", + "SingletonHadamardRegistry", +] # adapted from: # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py @@ -59,6 +63,7 @@ def deterministic_hadamard_matrix(size: int): # https://github.com/Dao-AILab/fast-hadamard-transform/tree/master +# ToDo: should no longer be random, call something else --> different generation type than scipy? def random_hadamard_matrix(size: int) -> torch.Tensor: """ Produces a randomly generated Hadamard matrix. @@ -73,7 +78,8 @@ def random_hadamard_matrix(size: int) -> torch.Tensor: # the matrix generated to be reproducible # Benefits: support other shapes / non powers of 2, support randomization - Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + # Q = torch.randint(low=1, high=2, size=(size,)).to(torch.float64) + Q = torch.ones(size).to(torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) return _matmul_hadU(Q) @@ -129,7 +135,7 @@ def _matmul_hadU(X, transpose=False): input = hadK.view(1, K, K).to(input) @ input # normalize - return input.view(X.shape) / torch.tensor(n).sqrt() + return input.view(X.shape) def _is_pow2(n): diff --git a/src/compressed_tensors/transforms/matrix_multiply.py b/src/compressed_tensors/transforms/matrix_multiply.py index a06d61f2..5216fd6e 100644 --- a/src/compressed_tensors/transforms/matrix_multiply.py +++ b/src/compressed_tensors/transforms/matrix_multiply.py @@ -12,14 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional, Union
+
 import torch
 from compressed_tensors.transforms import Transforms
-from compressed_tensors.transforms.utils import apply_matrix_transform
+from compressed_tensors.transforms.utils import (
+    SingletonMatrixRegistry,
+    apply_matrix_transform,
+)


-# TODO: fix loading
 @Transforms.register("matrix-mul")
 class MatrixMultiply(Transforms):
+    def __init__(
+        self,
+        name: str,
+        transform_name: str,
+        transform_data: Optional[torch.Tensor] = None,
+        size: Optional[int] = None,
+        empty: Optional[bool] = False,
+        device: Optional[Union[str, torch.device]] = "cpu",
+        dtype: Optional[torch.dtype] = torch.bfloat16,
+        *args,
+        **kwargs,
+    ):
+
+        if empty and size is None:
+            raise ValueError(
+                "size is required when setting up parameters for deserialization"
+            )
+
+        if not empty and transform_data is None:
+            raise ValueError(
+                "transform_data is required when initializing matrix-multiply transforms"
+            )
+
+        # name required to either pull a cached matrix or cache a matrix itself
+        # will assume each name corresponds to a unique matrix
+        self.name = name
+        self.matrix_registry = SingletonMatrixRegistry()
+
+        # Can we get rid of the size for deserialization?
+        if empty:
+            transform = torch.empty((size, size)).to(dtype)
+        else:
+            transform = self.fetch(transform_data).to(dtype).to(device)
+
+        super().__init__(transform=transform, transform_name=transform_name)
+
+        if not self.matrix_registry.contains(self.name):
+            self.matrix_registry.set_matrix(self.name, self.transform)
+
+    def fetch(self, transform_data: torch.Tensor):
+        transform = self.matrix_registry.get_matrix(self.name)
+        if transform is None:
+            return transform_data
+        return transform
+
     def apply(
         self,
         input_tensor: torch.Tensor,
diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py
index 162269c5..20e52fda 100644
--- a/src/compressed_tensors/transforms/random_hadamard.py
+++ b/src/compressed_tensors/transforms/random_hadamard.py
@@ -12,22 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from typing import Optional, Union

 import torch
 from compressed_tensors.transforms import Transforms
 from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix
-from compressed_tensors.transforms.utils import apply_matrix_transform
+from compressed_tensors.transforms.utils import (
+    SingletonMatrixRegistry,
+    apply_matrix_transform,
+)


+# TODO: allow randomness for both potentially, separate by generation type
+# this will make randomness a creation arg instead
 @Transforms.register("random-hadamard")
 class RandomHadamard(Transforms):
     def __init__(
         self,
         size: int,
+        transform_name: str,
+        permutation_name: str,
         empty: Optional[bool] = False,
-        device: Optional[Union[str, torch.device]] = "cuda",
+        device: Optional[Union[str, torch.device]] = "cpu",
         dtype: Optional[torch.dtype] = torch.bfloat16,
+        *args,
+        **kwargs,
     ):
         """
         Produces a randomly generated matrix with dims (size, size), with values
@@ -52,13 +62,38 @@ def __init__(
         we will not have to store the entire matrix. Will need to consider
         accuracy implications.
""" + self.size = size + self.normalized_size = math.sqrt(self.size) + # TODO: potentially lives outside of the registry + # And caching is controlled by llmcompressor + self.matrix_registry = SingletonMatrixRegistry() - if not empty: - transform = random_hadamard_matrix(size=size) + if empty: + transform = torch.empty((size, size)).to(dtype) + permutation = torch.empty((size)).to(dtype).to(device) else: - transform = torch.empty((size, size)) + transform = self.fetch().to(dtype).to(device) + permutation = ( + (torch.randint(low=0, high=2, size=(self.size,)) * 2 - 1) + .to(dtype) + .to(device) + ) - super().__init__(transform=transform, device=device, dtype=dtype) + super().__init__( + transform=transform, + permutation=permutation, + transform_name=transform_name, + permutation_name=permutation_name, + ) + + if not self.matrix_registry.contains(size): + self.matrix_registry.set_matrix(self.size, self.transform) + + def fetch(self): + transform = self.matrix_registry.get_matrix(self.size) + if transform is None: + transform = random_hadamard_matrix(size=self.size) + return transform def apply( self, @@ -67,7 +102,7 @@ def apply( first: bool = True, ) -> torch.Tensor: return apply_matrix_transform( - transform=self.transform, + transform=(self.transform * self.permutation) / self.normalized_size, input_tensor=input_tensor, transpose=transpose, first=first, @@ -92,7 +127,7 @@ def inverse_apply( transpose = not transpose return apply_matrix_transform( - transform=self.transform, + transform=(self.transform * self.permutation) / self.normalized_size, input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/src/compressed_tensors/transforms/temp.py b/src/compressed_tensors/transforms/temp.py new file mode 100644 index 00000000..1ed7643e --- /dev/null +++ b/src/compressed_tensors/transforms/temp.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix
+from compressed_tensors.transforms.utils import (
+    SingletonMatrixRegistry,
+)
+
+
+size = 2048
+dtype = torch.bfloat16
+hadamard_registry = SingletonMatrixRegistry()
+# fetch the deterministic hadamard from the registry, if already precomputed
+deterministic_had = hadamard_registry.get_matrix(size)
+if deterministic_had is None:
+    deterministic_had = random_hadamard_matrix(size=size).to(dtype)
+    hadamard_registry.set_matrix(size, deterministic_had)
+
+out = random_hadamard_matrix(size)
+Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
+Q = Q * 2 - 1
+
+breakpoint()
+new_out = out * Q
+new_out = new_out / torch.tensor(size).sqrt()
+assert torch.equal(torch.eye(size), torch.round(new_out @ new_out.T))
+breakpoint()
diff --git a/src/compressed_tensors/transforms/transform_scheme.py b/src/compressed_tensors/transforms/transform_scheme.py
index f1770cc4..74d300bb 100644
--- a/src/compressed_tensors/transforms/transform_scheme.py
+++ b/src/compressed_tensors/transforms/transform_scheme.py
@@ -30,16 +30,15 @@ class TransformationScheme(BaseModel):
     :param groups: includes TransformationArgs containing the information about the
         layers that should be targeted by the specified transform. By providing a list,
         users have the ability to apply the same transform type (with the same set
         of arguments) to different layers.
+    :param shared: whether an identical transform is to be used for all the groups
     :param transform_creation_args: arguments needed to initialize the transform,
         if any
-    :param global_transform: whether an identical transform is applied to all the
-        TransformationArgs in the groups list
     """

     transform_type: str
     groups: List[TransformationArgs]
-    global_transform: bool = False
+    shared: bool = False
     transform_creation_args: Optional[Dict[str, Any]] = None

     @field_validator("transform_type", mode="before")
diff --git a/src/compressed_tensors/transforms/utils.py b/src/compressed_tensors/transforms/utils.py
index 997c91f1..83364155 100644
--- a/src/compressed_tensors/transforms/utils.py
+++ b/src/compressed_tensors/transforms/utils.py
@@ -15,7 +15,27 @@
 import torch

-__all__ = ["apply_matrix_transform"]
+__all__ = ["apply_matrix_transform", "SingletonMatrixRegistry"]
+
+
+class SingletonMatrixRegistry:
+    _instance = None
+
+    def __new__(cls):
+        # Check if the instance already exists
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._data = {}  # Initialize the data storage
+        return cls._instance
+
+    def set_matrix(self, key, value):
+        self._data[key] = value
+
+    def get_matrix(self, key):
+        return self._data.get(key, None)
+
+    def contains(self, key):
+        return key in self._data


 def apply_matrix_transform(
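Not part of the diff: a minimal sketch of the behavior the new SingletonMatrixRegistry in utils.py is relied on for, assuming the class is importable as shown in the hunk above. Every construction returns the same instance backed by one shared dict, so transforms that key matrices by size (Hadamard, RandomHadamard) or by name (MatrixMultiply) can reuse a single cached tensor instead of materializing a copy per module, which is the memory saving noted in the apply.py comment.

import torch

from compressed_tensors.transforms.utils import SingletonMatrixRegistry

# both constructions resolve to the same instance and the same backing dict
registry_a = SingletonMatrixRegistry()
registry_b = SingletonMatrixRegistry()
assert registry_a is registry_b

size = 8
if not registry_a.contains(size):
    # placeholder matrix purely for illustration; the transforms cache a
    # (size, size) Hadamard matrix under this key
    registry_a.set_matrix(size, torch.eye(size))

# a later transform of the same size sees the cached tensor, no extra copy
assert registry_b.get_matrix(size) is registry_a.get_matrix(size)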
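Also not part of the diff: a standalone check of the arithmetic behind the updated RandomHadamard.apply, which now multiplies by (transform * permutation) / sqrt(size) rather than a pre-normalized matrix. Assuming deterministic_hadamard_matrix returns an unnormalized +/-1 matrix (as in hadamard_utils.py) and the permutation is a +/-1 vector, the column-wise sign flips plus the sqrt(size) normalization still yield an orthogonal rotation; this is the same property the assertion in temp.py probes.

import math

import torch

from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix

size = 16  # power of two, as required by the deterministic construction
H = torch.Tensor(deterministic_hadamard_matrix(size=size)).to(torch.float64)
q = (torch.randint(low=0, high=2, size=(size,)) * 2 - 1).to(torch.float64)  # +/-1 entries

# column-wise sign flips followed by normalization, mirroring apply()
rotation = (H * q) / math.sqrt(size)

identity = torch.eye(size, dtype=torch.float64)
assert torch.allclose(rotation @ rotation.T, identity, atol=1e-6)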