@@ -284,9 +284,8 @@ def logp(self):
284
284
"""Compiled log probability density function"""
285
285
return self .model .fn (self .logpt )
286
286
287
- @property
288
- def logp_elemwise (self ):
289
- return self .model .fn (self .logp_elemwiset )
287
+ def logp_elemwise (self , vars = None , jacobian = True ):
288
+ return self .model .fn (self .logp_elemwiset (vars = vars , jacobian = jacobian ))
290
289
291
290
def dlogp (self , vars = None ):
292
291
"""Compiled log probability density gradient function"""
@@ -728,6 +727,66 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
728
727
}
729
728
return ValueGradFunction (costs , grad_vars , extra_vars_and_values , ** kwargs )
730
729
730
+ def logp_elemwiset (
731
+ self ,
732
+ vars : Optional [Union [Variable , List [Variable ]]] = None ,
733
+ jacobian : bool = True ,
734
+ ) -> List [Variable ]:
735
+ """Elemwise log-probability of the model.
736
+
737
+ Parameters
738
+ ----------
739
+ vars: list of random variables or potential terms, optional
740
+ Compute the gradient with respect to those variables. If None, use all
741
+ free and observed random variables, as well as potential terms in model.
742
+ jacobian
743
+ Whether to include jacobian terms in logprob graph. Defaults to True.
744
+
745
+ Returns
746
+ -------
747
+ Elemwise logp terms for ecah requested variable, in the same order of input.
748
+ """
749
+ if vars is None :
750
+ vars = self .free_RVs + self .observed_RVs + self .potentials
751
+ elif not isinstance (vars , (list , tuple )):
752
+ vars = [vars ]
753
+
754
+ # We need to separate random variables from potential terms, and remember their
755
+ # original order so that we can merge them together in the same order at the end
756
+ rv_values = {}
757
+ potentials = []
758
+ rv_order , potential_order = [], []
759
+ for i , var in enumerate (vars ):
760
+ value_var = self .rvs_to_values .get (var )
761
+ if value_var is not None :
762
+ rv_values [var ] = value_var
763
+ rv_order .append (i )
764
+ else :
765
+ if var in self .potentials :
766
+ potentials .append (var )
767
+ potential_order .append (i )
768
+ else :
769
+ raise ValueError (
770
+ f"Requested variable { var } not found among the model variables"
771
+ )
772
+
773
+ rv_logps = []
774
+ if rv_values :
775
+ rv_logps = logpt (list (rv_values .keys ()), rv_values , sum = False , jacobian = jacobian )
776
+ if not isinstance (rv_logps , list ):
777
+ rv_logps = [rv_logps ]
778
+
779
+ # Replace random variables by their value variables in potential terms
780
+ potential_logps = []
781
+ if potentials :
782
+ potential_logps , _ = rvs_to_value_vars (potentials , apply_transforms = True )
783
+
784
+ logp_elemwise = [None ] * len (vars )
785
+ for logp_order , logp in zip ((rv_order + potential_order ), (rv_logps + potential_logps )):
786
+ logp_elemwise [logp_order ] = logp
787
+
788
+ return logp_elemwise
789
+
731
790
@property
732
791
def logpt (self ):
733
792
"""Aesara scalar of log-probability of the model"""
0 commit comments