Raise NotImplementedError for SplineWrapper gradient operation #2211
Changes from 1 commit
@@ -378,6 +378,18 @@ class SplineWrapper (theano.Op):
     def __init__(self, spline):
         self.spline = spline
 
+    @property
+    def grad_op(self):
+        if not hasattr(self, '_grad_op'):
+            try:
+                self._grad_op = SplineWrapper(self.spline.derivative())
+            except ValueError:
+                self._grad_op = None
+
+        if self._grad_op is None:
+            raise NotImplementedError('Spline of order 0 is not differentiable')
+        return self._grad_op
+
     def perform(self, node, inputs, output_storage):
         x, = inputs
         output_storage[0][0] = np.asarray(self.spline(x))
@@ -386,13 +398,4 @@ def grad(self, inputs, grads):
         x, = inputs
         x_grad, = grads
 
-        if not hasattr(self, 'grad_op'):
-            try:
-                self.grad_op = SplineWrapper(self.spline.derivative())
-            except ValueError:
-                self.grad_op = None
-
-        if self.grad_op is None:
-            raise NotImplementedError('Spline of order 0 is not differentiable')
-        else:
-            return [x_grad * self.grad_op(x)]
+        return [x_grad * self.grad_op(x)]
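To make the cached-property behaviour concrete, here is a minimal standalone sketch. It is illustrative only: LazySplineWrapper and the sine-based test data are invented for this example, and scipy.interpolate.UnivariateSpline stands in for whatever spline object the real SplineWrapper wraps. The caching mirrors the grad_op property above: the derivative wrapper is built once on first access and stored in _grad_op, and an order-0 spline, whose derivative() raises ValueError in SciPy, surfaces as NotImplementedError.

```python
import numpy as np
from scipy.interpolate import UnivariateSpline


class LazySplineWrapper:
    """Plain-Python sketch of the caching pattern (not a Theano op)."""

    def __init__(self, spline):
        self.spline = spline

    @property
    def grad_op(self):
        # Build the derivative wrapper at most once, then reuse it.
        if not hasattr(self, '_grad_op'):
            try:
                self._grad_op = LazySplineWrapper(self.spline.derivative())
            except ValueError:
                # SciPy refuses to differentiate an order-0 spline.
                self._grad_op = None

        if self._grad_op is None:
            raise NotImplementedError('Spline of order 0 is not differentiable')
        return self._grad_op


x = np.linspace(0, 10, 50)
y = np.sin(x)

cubic = LazySplineWrapper(UnivariateSpline(x, y, k=3))
assert cubic.grad_op is cubic.grad_op          # cached: built only on first access

step = LazySplineWrapper(UnivariateSpline(x, y, k=1).derivative())  # order-0 spline
try:
    step.grad_op
except NotImplementedError as err:
    print(err)                                  # Spline of order 0 is not differentiable
```

In the PR's op this means the derivative spline, and the SplineWrapper around it, are created at most once per instance, no matter how many times the gradient is requested.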
Review discussion (inline comments on the diff above):

- "You can create the new op right here. Pure Theano code is expected for the grad method."
- "To avoid O(n) calculations on each call of …"
- "You can memoize the call."
- "Memoize op creation, to be more accurate."
- "Yes, I thought about it. But I'm concerned that in this case the gradient calculation time becomes non-deterministic. For example, it might significantly bias …"
- "Why? Functions are compiled after the graph is constructed. That will not affect runtime."
- "Hmm, you are right. I added lazy creation of the derivatives."
- "Nice!"
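The exchange about compile time versus runtime can be illustrated with a small Theano example. This is a hedged sketch, not part of the PR: CountingSquare is a toy op invented here purely to show that Op.grad runs only while theano.grad builds the symbolic graph, so any op creation done there happens before compilation and never inside the compiled function's runtime loop.

```python
import numpy as np
import theano
import theano.tensor as tt


class CountingSquare(theano.Op):
    """Toy op (illustration only, not SplineWrapper): squares its input
    and counts how many times grad() is invoked."""
    itypes = [tt.dscalar]
    otypes = [tt.dscalar]
    grad_calls = 0

    def perform(self, node, inputs, output_storage):
        x, = inputs
        output_storage[0][0] = np.asarray(x ** 2)

    def grad(self, inputs, grads):
        CountingSquare.grad_calls += 1       # any op creation would happen here
        x, = inputs
        g, = grads
        return [g * 2 * x]


x = tt.dscalar('x')
dy_dx = tt.grad(CountingSquare()(x), x)      # grad() runs here, while the graph is built
f = theano.function([x], dy_dx)              # compiled once

print(CountingSquare.grad_calls)             # 1
print(f(3.0), f(4.0))                        # 6.0 8.0 -- plain numeric evaluation
print(CountingSquare.grad_calls)             # still 1: runtime calls never re-enter grad()
```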