77
88from collections import UserDict
99from contextlib import AbstractContextManager
10- from typing import (
11- TYPE_CHECKING ,
12- Any ,
13- Callable ,
14- Dict ,
15- List ,
16- Optional ,
17- Set ,
18- Tuple ,
19- Union ,
20- cast ,
21- overload ,
22- )
10+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , cast , overload
2311
2412import numpy as np
2513import theano .graph .basic
@@ -69,15 +57,15 @@ class _TraceDict(UserDict):
6957 ~~~~~~~~~~
7058 varnames: list of strings"""
7159
72- varnames : List [str ]
60+ varnames : list [str ]
7361 _len : int
7462 data : Point
7563
7664 def __init__ (
7765 self ,
78- point_list : Optional [ List [ Point ]] = None ,
79- multi_trace : Optional [ MultiTrace ] = None ,
80- dict_ : Optional [ Point ] = None ,
66+ point_list : list [ Point ] | None = None ,
67+ multi_trace : MultiTrace | None = None ,
68+ dict_ : Point | None = None ,
8169 ):
8270 """"""
8371 if multi_trace :
@@ -134,11 +122,11 @@ def apply_slice(arr: np.ndarray) -> np.ndarray:
134122 return _TraceDict (dict_ = sliced_dict )
135123
136124 @overload
137- def __getitem__ (self , item : Union [ str , HasName ] ) -> np .ndarray :
125+ def __getitem__ (self , item : str | HasName ) -> np .ndarray :
138126 ...
139127
140128 @overload
141- def __getitem__ (self , item : Union [ slice , int ] ) -> _TraceDict :
129+ def __getitem__ (self , item : slice | int ) -> _TraceDict :
142130 ...
143131
144132 def __getitem__ (self , item ):
@@ -155,13 +143,13 @@ def __getitem__(self, item):
155143
156144
157145def fast_sample_posterior_predictive (
158- trace : Union [ MultiTrace , Dataset , InferenceData , List [ Dict [str , np .ndarray ] ]],
159- samples : Optional [ int ] = None ,
160- model : Optional [ Model ] = None ,
161- var_names : Optional [ List [ str ]] = None ,
146+ trace : MultiTrace | Dataset | InferenceData | list [ dict [str , np .ndarray ]],
147+ samples : int | None = None ,
148+ model : Model | None = None ,
149+ var_names : list [ str ] | None = None ,
162150 keep_size : bool = False ,
163151 random_seed = None ,
164- ) -> Dict [str , np .ndarray ]:
152+ ) -> dict [str , np .ndarray ]:
165153 """Generate posterior predictive samples from a model given a trace.
166154
167155 This is a vectorized alternative to the standard ``sample_posterior_predictive`` function.
@@ -250,7 +238,7 @@ def fast_sample_posterior_predictive(
250238
251239 assert isinstance (_trace , _TraceDict )
252240
253- _samples : List [int ] = []
241+ _samples : list [int ] = []
254242 # temporary replacement for more complicated logic.
255243 max_samples : int = len_trace
256244 if samples is None or samples == max_samples :
@@ -289,7 +277,7 @@ def fast_sample_posterior_predictive(
289277 _ETPParent = UserDict
290278
291279 class _ExtendableTrace (_ETPParent ):
292- def extend_trace (self , trace : Dict [str , np .ndarray ]) -> None :
280+ def extend_trace (self , trace : dict [str , np .ndarray ]) -> None :
293281 for k , v in trace .items ():
294282 if k in self .data :
295283 self .data [k ] = np .concatenate ((self .data [k ], v ))
@@ -301,7 +289,7 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
301289 strace = _trace if s == len_trace else _trace [slice (0 , s )]
302290 try :
303291 values = posterior_predictive_draw_values (cast (List [Any ], vars ), strace , s )
304- new_trace : Dict [str , np .ndarray ] = {k .name : v for (k , v ) in zip (vars , values )}
292+ new_trace : dict [str , np .ndarray ] = {k .name : v for (k , v ) in zip (vars , values )}
305293 ppc_trace .extend_trace (new_trace )
306294 except KeyboardInterrupt :
307295 pass
@@ -313,8 +301,8 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
313301
314302
315303def posterior_predictive_draw_values (
316- vars : List [Any ], trace : _TraceDict , samples : int
317- ) -> List [np .ndarray ]:
304+ vars : list [Any ], trace : _TraceDict , samples : int
305+ ) -> list [np .ndarray ]:
318306 with _PosteriorPredictiveSampler (vars , trace , samples , None ) as sampler :
319307 return sampler .draw_values ()
320308
@@ -323,25 +311,25 @@ class _PosteriorPredictiveSampler(AbstractContextManager):
323311 """The process of posterior predictive sampling is quite complicated so this provides a central data store."""
324312
325313 # inputs
326- vars : List [Any ]
314+ vars : list [Any ]
327315 trace : _TraceDict
328316 samples : int
329- size : Optional [ int ] # not supported!
317+ size : int | None # not supported!
330318
331319 # other slots
332320 logger : logging .Logger
333321
334322 # for the search
335- evaluated : Dict [int , np .ndarray ]
336- symbolic_params : List [ Tuple [int , Any ]]
323+ evaluated : dict [int , np .ndarray ]
324+ symbolic_params : list [ tuple [int , Any ]]
337325
338326 # set by make_graph...
339- leaf_nodes : Dict [str , Any ]
340- named_nodes_parents : Dict [str , Any ]
341- named_nodes_children : Dict [str , Any ]
327+ leaf_nodes : dict [str , Any ]
328+ named_nodes_parents : dict [str , Any ]
329+ named_nodes_children : dict [str , Any ]
342330 _tok : contextvars .Token
343331
344- def __init__ (self , vars , trace : _TraceDict , samples , model : Optional [ Model ] , size = None ):
332+ def __init__ (self , vars , trace : _TraceDict , samples , model : Model | None , size = None ):
345333 if size is not None :
346334 raise NotImplementedError (
347335 "sample_posterior_predictive does not support the size argument at this time."
@@ -361,7 +349,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
361349 vectorized_ppc .reset (self ._tok )
362350 return False
363351
364- def draw_values (self ) -> List [np .ndarray ]:
352+ def draw_values (self ) -> list [np .ndarray ]:
365353 vars = self .vars
366354 trace = self .trace
367355 samples = self .samples
@@ -438,8 +426,8 @@ def draw_values(self) -> List[np.ndarray]:
438426 # the below makes sure the graph is evaluated in order
439427 # test_distributions_random::TestDrawValues::test_draw_order fails without it
440428 # The remaining params that must be drawn are all hashable
441- to_eval : Set [int ] = set ()
442- missing_inputs : Set [int ] = {j for j , p in self .symbolic_params }
429+ to_eval : set [int ] = set ()
430+ missing_inputs : set [int ] = {j for j , p in self .symbolic_params }
443431
444432 while to_eval or missing_inputs :
445433 if to_eval == missing_inputs :
@@ -477,19 +465,19 @@ def init(self) -> None:
477465 from the posterior predictive distribution. Notably it initializes the
478466 ``_DrawValuesContext`` bookkeeping object and evaluates the "fast drawable"
479467 parts of the model."""
480- vars : List [Any ] = self .vars
468+ vars : list [Any ] = self .vars
481469 trace : _TraceDict = self .trace
482470 samples : int = self .samples
483- leaf_nodes : Dict [str , Any ]
484- named_nodes_parents : Dict [str , Any ]
485- named_nodes_children : Dict [str , Any ]
471+ leaf_nodes : dict [str , Any ]
472+ named_nodes_parents : dict [str , Any ]
473+ named_nodes_children : dict [str , Any ]
486474
487475 # initialization phase
488476 context = _DrawValuesContext .get_context ()
489477 assert isinstance (context , _DrawValuesContext )
490478 with context :
491479 drawn = context .drawn_vars
492- evaluated : Dict [int , Any ] = {}
480+ evaluated : dict [int , Any ] = {}
493481 symbolic_params = []
494482 for i , var in enumerate (vars ):
495483 if is_fast_drawable (var ):
@@ -534,7 +522,7 @@ def make_graph(self) -> None:
534522 else :
535523 self .named_nodes_children [k ].update (nnc [k ])
536524
537- def draw_value (self , param , trace : Optional [ _TraceDict ] = None , givens = None ):
525+ def draw_value (self , param , trace : _TraceDict | None = None , givens = None ):
538526 """Draw a set of random values from a distribution or return a constant.
539527
540528 Parameters
@@ -559,7 +547,7 @@ def random_sample(
559547 param ,
560548 point : _TraceDict ,
561549 size : int ,
562- shape : Tuple [int , ...],
550+ shape : tuple [int , ...],
563551 ) -> np .ndarray :
564552 val = meth (point = point , size = size )
565553 try :
@@ -591,7 +579,7 @@ def random_sample(
591579 elif hasattr (param , "random" ) and param .random is not None :
592580 model = modelcontext (None )
593581 assert isinstance (model , Model )
594- shape : Tuple [int , ...] = tuple (_param_shape (param , model ))
582+ shape : tuple [int , ...] = tuple (_param_shape (param , model ))
595583 return random_sample (param .random , param , point = trace , size = samples , shape = shape )
596584 elif (
597585 hasattr (param , "distribution" )
@@ -602,7 +590,7 @@ def random_sample(
602590 # shape inspection for ObservedRV
603591 dist_tmp = param .distribution
604592 try :
605- distshape : Tuple [int , ...] = tuple (param .observations .shape .eval ())
593+ distshape : tuple [int , ...] = tuple (param .observations .shape .eval ())
606594 except AttributeError :
607595 distshape = tuple (param .observations .shape )
608596
@@ -689,7 +677,7 @@ def random_sample(
689677 raise ValueError ("Unexpected type in draw_value: %s" % type (param ))
690678
691679
692- def _param_shape (var_desig , model : Model ) -> Tuple [int , ...]:
680+ def _param_shape (var_desig , model : Model ) -> tuple [int , ...]:
693681 if isinstance (var_desig , str ):
694682 v = model [var_desig ]
695683 else :
0 commit comments