
Full torch.func support #224

@frostedoyster

Following #219, sphericart-torch operations already work with eager autograd, torch.compile, vmap, and jacfwd. What still does not work reliably are the reverse-mode torch.func transformations over the compiled custom-operation path: grad, jacrev, vmap(grad), and therefore hessian.
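For context, `torch.func.hessian` composes forward-mode over reverse-mode, i.e. `jacfwd(jacrev(f))`, so whenever `jacrev` breaks, `hessian` breaks with it. The snippet below is a minimal check of that composition on a pure-PyTorch toy function. `toy_energy` is a hypothetical stand-in, not a sphericart operation. Replacing it with a call through the compiled sphericart custom ops is where the reverse-mode problem would be expected to show up.

```python
import torch
from torch.func import hessian, jacfwd, jacrev


# Hypothetical pure-PyTorch stand-in for a scalar function of coordinates;
# it is NOT a sphericart operation.
def toy_energy(x: torch.Tensor) -> torch.Tensor:
    return (x * torch.sin(x)).sum()


x = torch.tensor([0.3, -1.2, 2.0], dtype=torch.float64)

# hessian(f) is built as jacfwd(jacrev(f)): forward-mode over reverse-mode,
# so it inherits any failure of jacrev on the same function.
h_builtin = hessian(toy_energy)(x)
h_composed = jacfwd(jacrev(toy_energy))(x)

# Analytic Hessian of sum(x * sin(x)): diagonal, entries 2*cos(x) - x*sin(x).
h_exact = torch.diag(2 * torch.cos(x) - x * torch.sin(x))

print(torch.allclose(h_builtin, h_composed), torch.allclose(h_builtin, h_exact))
```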

The core difficulty is that our backward formulas themselves call custom operations to compute derivatives, so higher-order transforms have to trace through custom ops inside backward, not only in forward. We added fake kernels and vmap rules, which is enough for batching and forward-mode AD, but it does not fully solve reverse-mode transformability through these "nested" custom calls.
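The nesting pattern can be sketched in eager mode with `torch.autograd.Function`, as a rough analogue of the compiled custom-op path. It does not use `torch.library` and is not expected to reproduce the failure; it only shows the structure. Here the hypothetical `ToySin.backward` calls another custom function, `ToyCos`, so `grad(grad(...))` and `vmap(grad(...))` must differentiate through a custom call made inside a backward pass. Defining `forward` without `ctx` plus a separate `setup_context` is what `torch.func` asks of an `autograd.Function`, and `generate_vmap_rule = True` lets PyTorch derive the batching rule instead of writing one by hand.

```python
import torch
from torch.func import grad, vmap


# Hypothetical toy ops standing in for a forward op and its "derivative op".
class ToySin(torch.autograd.Function):
    generate_vmap_rule = True  # let torch.func derive the vmap rule

    @staticmethod
    def forward(x):
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # The backward formula itself calls another custom op.
        return grad_out * ToyCos.apply(x)


class ToyCos(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return torch.cos(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # d/dx cos(x) = -sin(x), again computed through a custom op.
        return -grad_out * ToySin.apply(x)


f = ToySin.apply
x = torch.tensor(0.7, dtype=torch.float64)
xs = torch.linspace(-1.0, 1.0, 5, dtype=torch.float64)

d1 = grad(f)(x)              # reverse mode: cos(x)
d2 = grad(grad(f))(x)        # reverse over reverse, through ToyCos: -sin(x)
batched = vmap(grad(f))(xs)  # batched first derivatives: cos(xs)

print(torch.allclose(d1, torch.cos(x)),
      torch.allclose(d2, -torch.sin(x)),
      torch.allclose(batched, torch.cos(xs)))
```

In the real extension the inner call is a compiled custom op rather than an `autograd.Function`, which is exactly where the reverse-mode transforms listed above lose track of it.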
