
Commit 340f02c

jackson-tsang578 authored and pytorchmergebot committed
make it clearer (in docs) one can double decorate with torch.library.impl_* APIs (pytorch#137608)
Fixes pytorch#120503. The fix was originally attempted by @soxand16 in PR pytorch#121469; that PR was almost ready to merge but went stale (over 6 months old). This PR implements the original fix, refactored for clarity. CC: @zou3519 Pull Request resolved: pytorch#137608 Approved by: https://github.com/zou3519
1 parent 6bbbb08 commit 340f02c

File tree: 1 file changed (+27, -1 lines)

torch/library.py

Lines changed: 27 additions & 1 deletion
@@ -538,6 +538,10 @@ def impl(qualname, types, func=None, *, lib=None):
     Please only use this if the implementation truly supports all device types;
     for example, this is true if it is a composition of built-in PyTorch operators.
 
+    This API may be used as a decorator. You can use nested decorators
+    with this API provided they return a function and are placed inside
+    this API (see Example 2).
+
     Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
 
     Args:
@@ -549,7 +553,7 @@ def impl(qualname, types, func=None, *, lib=None):
     Examples:
         >>> import torch
         >>> import numpy as np
-        >>>
+        >>> # Example 1: Register function.
         >>> # Define the operator
         >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
         >>>
@@ -561,6 +565,28 @@ def impl(qualname, types, func=None, *, lib=None):
         >>> x = torch.randn(3)
         >>> y = torch.ops.mylib.mysin(x)
         >>> assert torch.allclose(y, x.sin())
+        >>>
+        >>> # Example 2: Register function with decorator.
+        >>> def custom_decorator(func):
+        >>>     def wrapper(*args, **kwargs):
+        >>>         return func(*args, **kwargs) + 1
+        >>>     return wrapper
+        >>>
+        >>> # Define the operator
+        >>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor")
+        >>>
+        >>> # Add implementations for the operator
+        >>> @torch.library.impl("mylib::sin_plus_one", "cpu")
+        >>> @custom_decorator
+        >>> def f(x):
+        >>>     return torch.from_numpy(np.sin(x.numpy()))
+        >>>
+        >>> # Call the new operator from torch.ops.
+        >>> x = torch.randn(3)
+        >>>
+        >>> y1 = torch.ops.mylib.sin_plus_one(x)
+        >>> y2 = torch.sin(x) + 1
+        >>> assert torch.allclose(y1, y2)
     """
     return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
 
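The key requirement the new docs text states is that an inner decorator must return a function for stacking under a registration decorator to work. This can be sketched without PyTorch; the `register` helper and `registry` dict below are hypothetical stand-ins for `torch.library.impl`, not part of the PyTorch API:

```python
# Minimal sketch of double decoration, assuming a hypothetical
# `register` decorator in place of torch.library.impl.
registry = {}

def register(name):
    """Outer registration decorator (stand-in for torch.library.impl)."""
    def deco(func):
        registry[name] = func  # record whatever function it receives
        return func
    return deco

def plus_one(func):
    """Inner decorator: must return a *function* so the outer
    registration decorator receives a callable to register."""
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1
    return wrapper

@register("mylib::identity_plus_one")  # outer: registration API
@plus_one                              # inner: returns wrapped function
def f(x):
    return x

print(registry["mylib::identity_plus_one"](2))  # prints 3
```

Because `plus_one` returns `wrapper`, the registration decorator sees a plain callable and registers the already-wrapped behavior, which is exactly why the docs require nested decorators to sit inside the `torch.library.impl` call and return a function.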
