-
Notifications
You must be signed in to change notification settings - Fork 131
Implement pad
#748
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Implement pad
#748
Changes from 29 commits
Commits
Show all changes
38 commits
Select commit
Hold shift + click to select a range
91311d5
Refactor linspace, logspace, and geomspace to match numpy implementation
jessegrabowski a98b8ae
Add `pt.pad`
jessegrabowski 02566b6
Use subclassed `OpFromGraph` to represent `pad` Op
jessegrabowski 32b14d2
Add test for `flip`
jessegrabowski 3e65827
Address reviewer feedback
jessegrabowski c2b8465
Remove `inplace` argument to `set_subtensor`
jessegrabowski 8f213a3
Delay setting dtype of `xspace` Ops until after all computation to ma…
jessegrabowski be6ed82
Use `shape_padright` instead of `.reshape` tricks
jessegrabowski 9c76a8f
Add test for `dtype` kwarg on `xspace` Ops
jessegrabowski eeb9fa3
Save keyword arguments in `Pad` `OpFromGraph`
jessegrabowski c28faaa
Add test for arbitrary padding at higher dimensions
jessegrabowski ab99a1e
First draft JAX overload
jessegrabowski e93fa56
Expect symbolic `num` argument
jessegrabowski 2d55a06
Split `wrap_pad` into separate function; eliminate use of `scan`
jessegrabowski 48037c6
<<DO NOT MERGE>> testing notebook
jessegrabowski b4ffdc5
Add `reflect` and `symmetric` padding
jessegrabowski 6808648
Remove test notebook
jessegrabowski 0f9ca38
Correct reflect and symmetric implementations
jessegrabowski c1dd0bc
Fix docs, JAX test
jessegrabowski 3b34779
Remove `_broadcast_inputs` helper, update docstrings
jessegrabowski 4cd9702
Remove `OpFromGraph` and associated `JAX` dispatch
jessegrabowski dbda326
Revert "Remove `OpFromGraph` and associated `JAX` dispatch"
jessegrabowski 32ae3eb
Add issue link to `reflect_type` error message
jessegrabowski d543ed6
Move `flip` to `tensor/subtensor.py`, add docstring
jessegrabowski ba6c613
Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor…
jessegrabowski aa95403
Appease mypy
jessegrabowski 2a11ffa
Appease mypy, add docstring to `pad`
jessegrabowski 89f8cdf
Appease mypy, add docstring to `pad`
jessegrabowski 0f5e2ba
Fix doctests
jessegrabowski 74e39b2
Fix doctests
jessegrabowski d169928
Fix doctests
jessegrabowski e429bab
Propagate all optional arguments to JAX
jessegrabowski 8722bfe
Propagate all optional arguments to JAX
jessegrabowski 3f8a377
Appease mypy
jessegrabowski 163b441
Test `NUMBA` backend
jessegrabowski 2c9f727
I love mypy
jessegrabowski 19ed2c0
Skip failing numba test
jessegrabowski bbeb300
Merge branch 'main' into pad
jessegrabowski File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import jax.numpy as jnp | ||
|
||
from pytensor.link.jax.dispatch import jax_funcify | ||
from pytensor.tensor.pad import Pad | ||
|
||
|
||
fixed_kwargs = {"reflect": ["reflect_type"], "symmetric": ["reflect_type"]} | ||
|
||
|
||
@jax_funcify.register(Pad) | ||
def jax_funcify_pad(op, **kwargs): | ||
pad_mode = op.pad_mode | ||
expected_kwargs = fixed_kwargs.get(pad_mode, {}) | ||
mode_kwargs = {kwarg: getattr(op, kwarg) for kwarg in expected_kwargs} | ||
jessegrabowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def pad(x, pad_width, *args): | ||
print(args) | ||
return jnp.pad(x, pad_width, mode=pad_mode, **mode_kwargs) | ||
|
||
return pad |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.