13
13
from functools import partial
14
14
from typing import Any
15
15
16
+ import numpy .typing as npt
17
+
16
18
import torch
17
19
from botorch .optim .utils import (
18
20
_handle_numerical_errors ,
21
23
)
22
24
from botorch .optim .utils .numpy_utils import as_ndarray
23
25
from botorch .utils .context_managers import zero_grad_ctx
24
- from numpy import float64 as np_float64 , full as np_full , ndarray , zeros as np_zeros
26
+ from numpy import float64 as np_float64 , full as np_full , zeros as np_zeros
25
27
from torch import Tensor
26
28
27
29
@@ -82,10 +84,10 @@ def __init__(
82
84
self ,
83
85
closure : Callable [[], tuple [Tensor , Sequence [Tensor | None ]]],
84
86
parameters : dict [str , Tensor ],
85
- as_array : Callable [[Tensor ], ndarray ] = None , # pyre-ignore [9]
86
- as_tensor : Callable [[ndarray ], Tensor ] = torch .as_tensor ,
87
- get_state : Callable [[], ndarray ] = None , # pyre-ignore [9]
88
- set_state : Callable [[ndarray ], None ] = None , # pyre-ignore [9]
87
+ as_array : Callable [[Tensor ], npt . NDArray ] = None , # pyre-ignore [9]
88
+ as_tensor : Callable [[npt . NDArray ], Tensor ] = torch .as_tensor ,
89
+ get_state : Callable [[], npt . NDArray ] = None , # pyre-ignore [9]
90
+ set_state : Callable [[npt . NDArray ], None ] = None , # pyre-ignore [9]
89
91
fill_value : float = 0.0 ,
90
92
persistent : bool = True ,
91
93
) -> None :
@@ -140,11 +142,11 @@ def __init__(
140
142
141
143
self .fill_value = fill_value
142
144
self .persistent = persistent
143
- self ._gradient_ndarray : ndarray | None = None
145
+ self ._gradient_ndarray : npt . NDArray | None = None
144
146
145
147
def __call__ (
146
- self , state : ndarray | None = None , ** kwargs : Any
147
- ) -> tuple [ndarray , ndarray ]:
148
+ self , state : npt . NDArray | None = None , ** kwargs : Any
149
+ ) -> tuple [npt . NDArray , npt . NDArray ]:
148
150
if state is not None :
149
151
self .state = state
150
152
@@ -164,14 +166,14 @@ def __call__(
164
166
return value , grads
165
167
166
168
@property
167
- def state (self ) -> ndarray :
169
+ def state (self ) -> npt . NDArray :
168
170
return self ._get_state ()
169
171
170
172
@state .setter
171
- def state (self , state : ndarray ) -> None :
173
+ def state (self , state : npt . NDArray ) -> None :
172
174
self ._set_state (state )
173
175
174
- def _get_gradient_ndarray (self , fill_value : float | None = None ) -> ndarray :
176
+ def _get_gradient_ndarray (self , fill_value : float | None = None ) -> npt . NDArray :
175
177
if self .persistent and self ._gradient_ndarray is not None :
176
178
if fill_value is not None :
177
179
self ._gradient_ndarray .fill (fill_value )
0 commit comments