[Transforms] Transform Args, Scheme, and Config #321

Open
wants to merge 4 commits into main

Conversation

Contributor

@kylesayrs commented May 21, 2025

Summary

  • Introduce TransformArgs, TransformScheme, and TransformConfig as structures that define recipes for applying transforms
  • Compared to [Transforms] Transform Arg, Scheme, and Data support #275, the requirement for ModuleTarget and TransformData has been removed, as have the transform_creation_args and call_args fields
  • Supports SpinQuant-like transform configs which target non-Linear modules such as q_attn and k_cache

TransformArgs

  • Arguments which define how and where a transform should be applied to a model

TransformScheme

  • Scheme used to parameterize a particular transform type and specify how and where it should be applied to the model
  • Includes a list of TransformArgs describing how and where the transform will be applied

TransformConfig

  • Configuration of transforms to be applied to a model. This config is to be serialized within a model's config.json file
  • The keys can be arbitrary strings, and a TransformScheme should be provided for each transform type
  • Includes preset configs for QUIP/QUIP# and Llama-SpinQuant

Example:

from compressed_tensors.transform import TransformArgs, TransformScheme, TransformsConfig

forward_args = TransformArgs(targets=["Linear"], location="input")  # non-mergeable
inverse_args = TransformArgs(targets=["Linear"], location="weight", side="input", inverse=True)

scheme = TransformScheme(
    type="hadamard",
    apply=[forward_args, inverse_args],
    randomize_modules=True
)

config = TransformsConfig(transform_groups={"v_transform": scheme})
# [transform locations] and {quant locations}:
# input  <- {input @ [input]}
# weight <- {[weight_output] @ weight @ [weight_input]}
# output <- {input @ weight.T @ [output]}
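
For additional context, here is a rough sketch of how a QUIP-style preset could be composed from the same pieces; the variable name and field values below are illustrative and may differ from the actual presets added in this PR:

# Illustrative only: the real QUIP preset in this PR may use different
# targets, sides, or scheme options.
QUIP_SKETCH = TransformsConfig(
    transform_groups={
        # "u" rotates the weight's output dimension and is undone on the output activations
        "u_transform": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets=["Linear"], location="weight", side="output"),
                TransformArgs(targets=["Linear"], location="output", inverse=True),
            ],
        ),
        # "v" rotates the input activations and is undone on the weight's input dimension
        "v_transform": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets=["Linear"], location="input"),
                TransformArgs(targets=["Linear"], location="weight", side="input", inverse=True),
            ],
        ),
    }
)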

Signed-off-by: Kyle Sayers <[email protected]>
@kylesayrs marked this pull request as ready for review May 21, 2025 13:41
Signed-off-by: Kyle Sayers <[email protected]>
Contributor

@brian-dellabetta left a comment

Looking good! A few comments.

to a particular model. The keys can be any arbitrary string
"""

transform_groups: Dict[str, TransformScheme]
Contributor

nit: since this lives in TransformsConfig, the name prefix is a little redundant

Suggested change
transform_groups: Dict[str, TransformScheme]
groups: Dict[str, TransformScheme]

Contributor Author

I agree, I prefer groups over config_groups as well

Member

I prefer groups too!

type: str
apply: List[TransformArgs] = Field(default_factory=list)
randomize_modules: bool = Field(default=False)
requires_grad: bool = Field(default=False)
Contributor

does requires_grad need to be True anywhere?

Contributor Author

@kylesayrs May 21, 2025

Does this answer your question?

:param requires_grad: True if weights include gradients for training

Contributor

@brian-dellabetta May 21, 2025

lol kyle. I don't think gradients are needed for SpinQuant, are they needed for QUIP?

Contributor Author

QUIP doesn't specify learned transforms, while SpinQuant does. However, in the current implementation you can still learn transforms with a QUIP-style config.

Contributor

OK, if I understand correctly, the updates in SpinQuant aren't calculated with backprop; they have their own loss/update function to be able to do Cayley SGD, so we might not need this.

Member

Quickly adding in here, since this is an important point and something we've been burned by in the past, particularly around not being able to support QAT. We should always have an option for controlling requires_grad for any parameters we inject, so we can pass the decision of how to optimize to the caller and not over-index on current use cases or assumptions. For SpinQuant specifically, a PyTorch optimizer could definitely be set up to optimize these parameters according to the Cayley SGD equations, since it actively needs the gradient information; PyTorch would then handle the calculations so the implementer doesn't have to.
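
To illustrate that point, a minimal, self-contained sketch (toy tensors, not this library's API) of how gradient information would flow into an injected transform parameter once requires_grad is enabled, leaving the choice of update rule to the caller:

import torch

weight = torch.randn(16, 16)                       # frozen base weight (stand-in)
transform = torch.nn.Parameter(torch.eye(16))      # injected transform, requires_grad=True

optimizer = torch.optim.SGD([transform], lr=1e-3)  # or a custom Cayley-SGD optimizer

x = torch.randn(4, 16)
loss = ((x @ (transform @ weight).T) ** 2).mean()  # any loss that touches the transform
loss.backward()                                    # autograd populates transform.grad
optimizer.step()                                   # the update rule itself is the caller's choice
optimizer.zero_grad()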

Collaborator

@dsikka left a comment

Two questions:

  1. How do we maintain the order of the transforms if multiple are being applied to a given input or weight?
  2. What utility will be inferring the transform size and whether it's applied on the right or left?

"""

targets: List[str]
location: Literal["input", "weight", "output", "k_cache", "q_attn"]
Collaborator

enum?
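
For reference, a minimal sketch of what the suggested enum could look like (hypothetical; the PR currently uses the Literal shown above):

from enum import Enum

class TransformLocation(str, Enum):
    INPUT = "input"
    WEIGHT = "weight"
    OUTPUT = "output"
    K_CACHE = "k_cache"
    Q_ATTN = "q_attn"

# Subclassing str keeps the values serializable as plain strings.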

@kylesayrs marked this pull request as draft May 21, 2025 20:01
Signed-off-by: Kyle Sayers <[email protected]>
Contributor Author

@kylesayrs commented May 21, 2025

@dsikka

  1. If you're only talking about matrix transforms, order does not matter because matrix multiplication is commutative
  2. FYI @anmarques suggested replacing side ∈ {"left", "right"} with side ∈ {"input", "output"} to clarify how the transform is applied to the weight. This is especially nice for QUIP, since you can directly see that the U and V shapes align in the config

To answer the question, this function uses location and side to determine which side of the weight the transform is applied to:
https://github.com/neuralmagic/compressed-tensors/pull/316/files#diff-d70cb4816d5b1e1ff0e5f5ea117da8d2e1e106b7419680ba41228a5064dad890R32
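
For illustration, a minimal sketch of the kind of logic that function implements; the function and argument names here are hypothetical, see the linked diff for the actual code:

import torch

def apply_transform_weight(weight: torch.Tensor, transform: torch.Tensor, side: str) -> torch.Tensor:
    # For nn.Linear, weight has shape (out_features, in_features), and the location
    # diagram above fuses transforms as [weight_output] @ weight @ [weight_input].
    if side == "input":    # acts on in_features, i.e. a right multiply
        return weight @ transform
    if side == "output":   # acts on out_features, i.e. a left multiply
        return transform @ weight
    raise ValueError(f"unexpected side: {side}")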

@kylesayrs marked this pull request as ready for review May 21, 2025 21:47
Signed-off-by: Kyle Sayers <[email protected]>
"""

targets: List[str]
location: Literal["input", "weight", "output", "k_cache", "q_attn"]
Member

I think concepts like KV cache, attention, and other locally nested modules or objects would be better suited under targets, especially with the new logic we're looking to add that supports injecting and targeting the local namespace within a module's forward function. My core concerns are around these 3 main pieces:

  1. Keeps target_names clean and focused by avoiding a growing list of special cases and centering this interface on core PyTorch concepts (inputs, outputs, parameters). The complexity and understanding instead move into the target's logic, which already has some elements specifically targeting module namespaces.
  2. Improves clarity and generality by making it explicit what can be selected from each module, function, or op, letting algorithms specify what to target rather than relying on implicit assumptions that may vary across implementations and require knowledge of the architectures within compressed tensors. This is especially important for supporting things like KV cache quantization and quantized attention, where there's active and changing research in addition to hardware-dependent pieces towards quantizing attention inputs vs. just the KV cache, or pre- vs. post-RoPE quantization for best accuracy.
  3. It simplifies extensibility: we could easily extend location to represent other parameter names or targeted portions of the input/output if needed.


targets: List[str]
location: Literal["input", "weight", "output", "k_cache", "q_attn"]
side: Optional[Literal["input", "output"]] = Field(default=None)
Member

I think using "input" and "output" for side overloads those terms, especially since they already exist in location and are reserved, general concepts in PyTorch’s namespace and documentation. In the context of an operation (as with location), "input" and "output" make sense because there's a clear computation graph: inputs are processed into outputs. But it can be ambiguous for a tensor or parameter without an explicit operation: does "input" mean the data going into a transform, or the left side of a multiplication? So we have to rely on more documentation to communicate the subject properly. In contrast, left and right matrix multiplies are standard linear algebra concepts, and PyTorch's underlying operations remain consistent.

return value

@model_validator(mode="after")
def determine_side(self):
Member

I don't think we should restrict here, as there could definitely be arguments for doing a left or right multiply based on the derivation of the algorithm in the future. Ideally that restriction would come from the algorithm implementation/definition in the recipe in LLM Compressor.



# quip / quip sharp
QUIP = TransformConfig(
Member

Is the plan to define these within compressed tensors for the algorithms? I think it's going to be tricky to keep them here since they can depend on the input model's architecture and the user's intentions (particularly around targeting layers and granularity). All of that is more easily controlled in LLM Compressor, and ideally that's what LLM Compressor is for: to define the algorithms, optimize them, and pass them into compressed tensors to make the edits. The original goal for compressed tensors was to remain model-agnostic, so we didn't have to import Transformers or anything else and could focus on just edits to the graph.
