[Transforms] Transform Args, Scheme, and Config #321
base: main
Conversation
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Looking good! A few comments.
    to a particular model. The keys can be any arbitrary string
    """

    transform_groups: Dict[str, TransformScheme]
nit: since this lives in TransformsConfig, the name prefix is a little redundant

Suggested change:
-    transform_groups: Dict[str, TransformScheme]
+    groups: Dict[str, TransformScheme]
I agree, I prefer groups over config_groups as well.
I prefer groups too!
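For illustration, a minimal sketch of how the config could look with the shorter field name being discussed in this thread. The class body is reconstructed from the diff fragments above, and the stand-in TransformScheme plus the docstring wording are assumptions rather than the PR's actual code:

```python
from typing import Dict

from pydantic import BaseModel


class TransformScheme(BaseModel):
    # Minimal stand-in for the scheme model shown later in this PR
    type: str


class TransformConfig(BaseModel):
    """
    Maps arbitrary string keys to the transform scheme applied
    to a particular model. The keys can be any arbitrary string.
    """

    # `groups` as suggested above; the PR currently names this `transform_groups`
    groups: Dict[str, TransformScheme]
```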
type: str
apply: List[TransformArgs] = Field(default_factory=list)
randomize_modules: bool = Field(default=False)
requires_grad: bool = Field(default=False)
Does requires_grad need to be True anywhere?
Does this answer your question?
:param requires_grad: True if weights include gradients for training
lol Kyle. I don't think gradients are needed for SpinQuant; are they needed for QuIP?
QuIP doesn't specify learned transforms, while SpinQuant does. However, in the current implementation you can still learn transforms with a QuIP-style config.
OK, if I understand correctly, the updates in SpinQuant aren't calculated with backprop; they have their own loss/update function to do Cayley SGD. So we might not need this.
Quickly adding in here since this is an important point and something we've been burned by a lot in the past, particularly around not being able to support QAT. We should always have an option for controlling requires_grad for any parameters we inject, so we can pass the decision of how to optimize to the caller and not overindex on current use cases or assumptions. For SpinQuant specifically, a PyTorch optimizer could definitely be set up to optimize these parameters according to the Cayley SGD equations, since it actively needs the gradient information, and PyTorch would then handle the calculations so the implementer doesn't have to.
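To make the point above concrete, here is a hedged sketch (not code from this PR) of letting the caller control requires_grad on an injected transform parameter and handing optimization to a standard PyTorch optimizer; the helper name, shapes, and loss are made up for illustration:

```python
import torch


def inject_transform(module: torch.nn.Module, size: int, requires_grad: bool) -> None:
    # Register the transform as a module parameter so it shows up in
    # module.parameters(); the caller decides whether it receives gradients.
    transform = torch.nn.Parameter(torch.eye(size), requires_grad=requires_grad)
    module.register_parameter("transform", transform)


linear = torch.nn.Linear(8, 8)
inject_transform(linear, size=8, requires_grad=True)

# Any optimizer (plain SGD here, but a custom Cayley-SGD-style optimizer would
# plug in the same way) can consume the gradients PyTorch computes.
optimizer = torch.optim.SGD([linear.transform], lr=1e-3)
loss = (linear.weight @ linear.transform).pow(2).sum()
loss.backward()
optimizer.step()
```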
Two questions:
- How do we maintain the order of the transforms if multiple are being applied to a given input or weight?
- What utility will be inferring the transform size and whether it's right/left?
""" | ||
|
||
targets: List[str] | ||
location: Literal["input", "weight", "output", "k_cache", "q_attn"] |
enum?
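A sketch of what the enum suggestion might look like, assuming a str-valued Enum so that plain strings in a serialized config still validate; the TransformLocation name is an assumption:

```python
from enum import Enum
from typing import List

from pydantic import BaseModel


class TransformLocation(str, Enum):
    INPUT = "input"
    WEIGHT = "weight"
    OUTPUT = "output"
    K_CACHE = "k_cache"
    Q_ATTN = "q_attn"


class TransformArgs(BaseModel):
    targets: List[str]
    # Pydantic accepts the raw string values ("input", "weight", ...) and
    # coerces them to TransformLocation members.
    location: TransformLocation
```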
Signed-off-by: Kyle Sayers <[email protected]>
To answer the question, this function uses location and side to determine which side of the weight to apply the transform to:
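The function referenced above isn't reproduced in this excerpt. As a rough, purely illustrative sketch of the idea (not the PR's implementation), choosing a side of the weight amounts to choosing a left or right matrix multiply:

```python
import torch


def apply_to_weight(weight: torch.Tensor, transform: torch.Tensor, side: str) -> torch.Tensor:
    # Illustrative only: the exact mapping between side and left/right
    # multiplication depends on the layer's weight layout (e.g. nn.Linear
    # stores its weight as [out_features, in_features]).
    if side == "input":
        return weight @ transform   # right multiply: acts on the input dimension
    elif side == "output":
        return transform @ weight   # left multiply: acts on the output dimension
    else:
        raise ValueError(f"unrecognized side: {side}")
```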
Signed-off-by: Kyle Sayers <[email protected]>
""" | ||
|
||
targets: List[str] | ||
location: Literal["input", "weight", "output", "k_cache", "q_attn"] |
I think concepts like KV cache, attention, and other locally nested modules or objects would be better suited under targets, especially with the new logic we're looking to add that supports injecting and targeting the local namespace within a module's forward function. My core concerns are around these 3 main pieces:
- Keeps target_names clean and focused by avoiding a growing list of special cases and centering this interface on core PyTorch concepts (inputs, outputs, parameters). The complexity and understanding instead move into the target's logic, which already has some elements specifically targeting module namespaces.
- Improves clarity and generality by making it explicit what can be selected from each module, function, or op, letting algorithms specify what to target rather than relying on implicit assumptions that may vary across implementations and require knowledge of the architectures within compressed tensors. This is especially important for supporting things like KV cache quantization and quantized attention, where there's active and changing research, in addition to hardware-dependent pieces, around quantizing attention inputs vs. just the KV cache, or pre- vs. post-RoPE quantization for best accuracy.
- Simplifies extensibility: location could ideally be extended easily to represent other parameter names or targeted portions of the input/output if needed.
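As one possible reading of this proposal (an interpretation, not an agreed design), location would stay limited to core PyTorch concepts and the attention/cache specifics would move into the target strings:

```python
from typing import List, Literal

from pydantic import BaseModel


class TransformArgs(BaseModel):
    # Special cases like the KV cache or attention inputs would be expressed
    # through the target strings, not through dedicated location values.
    targets: List[str]  # e.g. ["LlamaAttention.k_cache"] -- format is hypothetical
    location: Literal["input", "weight", "output"]
```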
targets: List[str]
location: Literal["input", "weight", "output", "k_cache", "q_attn"]
side: Optional[Literal["input", "output"]] = Field(default=None)
I think using "input" and "output" for side overloads those terms, especially since they already exist in location and are reserved, general concepts in PyTorch’s namespace and documentation. In the context of an operation (as with location), "input" and "output" make sense because there's a clear computation graph: inputs are processed into outputs. But it can be ambiguous for a tensor or parameter without an explicit operation: does "input" mean the data going into a transform, or the left side of a multiplication? So we have to rely on more documentation to communicate the subject properly. In contrast, left and right matrix multiplies are standard linear algebra concepts, and PyTorch's underlying operations remain consistent.
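A sketch of the alternative naming suggested here, using left/right multiplication semantics; the field names and default are assumptions:

```python
from typing import List, Literal, Optional

from pydantic import BaseModel, Field


class TransformArgs(BaseModel):
    targets: List[str]
    location: Literal["input", "weight", "output", "k_cache", "q_attn"]
    # "left" means the transform left-multiplies the tensor (T @ W);
    # "right" means it right-multiplies (W @ T).
    side: Optional[Literal["left", "right"]] = Field(default=None)
```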
return value

@model_validator(mode="after")
def determine_side(self):
I don't think we should restrict here, as there definitely could be arguments for doing a left or right multiply based on the derivation of the algorithm in the future. Ideally that restriction would come from the algorithm implementation/definition in the recipe in LLM Compressor.
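A sketch of the more permissive behavior being suggested: the validator only fills in a default when side is unset and never rejects an explicit choice. The default mapping shown is an assumption for illustration:

```python
from typing import List, Literal, Optional

from pydantic import BaseModel, Field, model_validator


class TransformArgs(BaseModel):
    targets: List[str]
    location: Literal["input", "weight", "output", "k_cache", "q_attn"]
    side: Optional[Literal["input", "output"]] = Field(default=None)

    @model_validator(mode="after")
    def determine_side(self):
        # Only infer a default when the caller did not specify one;
        # an explicitly provided side is always respected rather than restricted.
        if self.side is None:
            self.side = "output" if self.location == "output" else "input"
        return self
```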
# quip / quip sharp
QUIP = TransformConfig(
Is the plan to define these within compressed tensors for the algorithms? I think it's going to be tricky to keep them here since they can depend on the input model's architecture and the user's intentions (particularly around targeting layers and granularity). All of that is more easily controlled in LLM Compressor, and ideally that's what LLM Compressor is for: defining the algorithms, optimizing them, and passing them into compressed tensors to make the edits. The original goal for compressed tensors was to remain model-agnostic so we didn't have to import Transformers or anything else and could focus on just edits to the graph.
Summary
- ModuleTarget and TransformData have been removed, as well as the transform_creation_args and call_args fields
- q_attn and k_cache […]
- TransformArgs […]
- TransformScheme […]
- TransformationArgs […] onto which the transformation will be applied
- TransformConfig […] config.json file

Example:
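The example that followed in the original description isn't included in this excerpt. As a stand-in, here is a self-contained sketch that reconstructs minimal versions of the three classes from the diff fragments above and shows how a config could be built and serialized; the "hadamard" type, the target choices, and the serialization call are assumptions, not the PR's actual example:

```python
from typing import Dict, List, Literal, Optional

from pydantic import BaseModel, Field


class TransformArgs(BaseModel):
    targets: List[str]
    location: Literal["input", "weight", "output", "k_cache", "q_attn"]
    side: Optional[Literal["input", "output"]] = Field(default=None)


class TransformScheme(BaseModel):
    type: str
    apply: List[TransformArgs] = Field(default_factory=list)
    randomize_modules: bool = Field(default=False)
    requires_grad: bool = Field(default=False)


class TransformConfig(BaseModel):
    transform_groups: Dict[str, TransformScheme]


config = TransformConfig(
    transform_groups={
        "group_0": TransformScheme(
            type="hadamard",  # illustrative transform type
            apply=[
                TransformArgs(targets=["Linear"], location="weight", side="input"),
                TransformArgs(targets=["Linear"], location="weight", side="output"),
            ],
            requires_grad=True,
        ),
    }
)

# The config can then be serialized, e.g. for inclusion in a model's config.json
print(config.model_dump_json(indent=2))
```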