120
120
from pymc .logprob .utils import CheckParameterValue , check_potential_measurability
121
121
122
122
123
- class RVTransform (abc .ABC ):
123
+ class Transform (abc .ABC ):
124
124
ndim_supp = None
125
125
126
126
@abc .abstractmethod
@@ -174,10 +174,10 @@ class MeasurableTransform(MeasurableElemwise):
174
174
175
175
# Cannot use `transform` as name because it would clash with the property added by
176
176
# the `TransformValuesRewrite`
177
- transform_elemwise : RVTransform
177
+ transform_elemwise : Transform
178
178
measurable_input_idx : int
179
179
180
- def __init__ (self , * args , transform : RVTransform , measurable_input_idx : int , ** kwargs ):
180
+ def __init__ (self , * args , transform : Transform , measurable_input_idx : int , ** kwargs ):
181
181
self .transform_elemwise = transform
182
182
self .measurable_input_idx = measurable_input_idx
183
183
super ().__init__ (* args , ** kwargs )
@@ -444,7 +444,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
444
444
scalar_op = node .op .scalar_op
445
445
measurable_input_idx = 0
446
446
transform_inputs : Tuple [TensorVariable , ...] = (measurable_input ,)
447
- transform : RVTransform
447
+ transform : Transform
448
448
449
449
transform_dict = {
450
450
Exp : ExpTransform (),
@@ -559,7 +559,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
559
559
)
560
560
561
561
562
- class SinhTransform (RVTransform ):
562
+ class SinhTransform (Transform ):
563
563
name = "sinh"
564
564
ndim_supp = 0
565
565
@@ -570,7 +570,7 @@ def backward(self, value, *inputs):
570
570
return pt .arcsinh (value )
571
571
572
572
573
- class CoshTransform (RVTransform ):
573
+ class CoshTransform (Transform ):
574
574
name = "cosh"
575
575
ndim_supp = 0
576
576
@@ -589,7 +589,7 @@ def log_jac_det(self, value, *inputs):
589
589
)
590
590
591
591
592
- class TanhTransform (RVTransform ):
592
+ class TanhTransform (Transform ):
593
593
name = "tanh"
594
594
ndim_supp = 0
595
595
@@ -600,7 +600,7 @@ def backward(self, value, *inputs):
600
600
return pt .arctanh (value )
601
601
602
602
603
- class ArcsinhTransform (RVTransform ):
603
+ class ArcsinhTransform (Transform ):
604
604
name = "arcsinh"
605
605
ndim_supp = 0
606
606
@@ -611,7 +611,7 @@ def backward(self, value, *inputs):
611
611
return pt .sinh (value )
612
612
613
613
614
- class ArccoshTransform (RVTransform ):
614
+ class ArccoshTransform (Transform ):
615
615
name = "arccosh"
616
616
ndim_supp = 0
617
617
@@ -622,7 +622,7 @@ def backward(self, value, *inputs):
622
622
return pt .cosh (value )
623
623
624
624
625
- class ArctanhTransform (RVTransform ):
625
+ class ArctanhTransform (Transform ):
626
626
name = "arctanh"
627
627
ndim_supp = 0
628
628
@@ -633,7 +633,7 @@ def backward(self, value, *inputs):
633
633
return pt .tanh (value )
634
634
635
635
636
- class ErfTransform (RVTransform ):
636
+ class ErfTransform (Transform ):
637
637
name = "erf"
638
638
ndim_supp = 0
639
639
@@ -644,7 +644,7 @@ def backward(self, value, *inputs):
644
644
return pt .erfinv (value )
645
645
646
646
647
- class ErfcTransform (RVTransform ):
647
+ class ErfcTransform (Transform ):
648
648
name = "erfc"
649
649
ndim_supp = 0
650
650
@@ -655,7 +655,7 @@ def backward(self, value, *inputs):
655
655
return pt .erfcinv (value )
656
656
657
657
658
- class ErfcxTransform (RVTransform ):
658
+ class ErfcxTransform (Transform ):
659
659
name = "erfcx"
660
660
ndim_supp = 0
661
661
@@ -681,7 +681,7 @@ def calc_delta_x(value, prior_result):
681
681
return result [- 1 ]
682
682
683
683
684
- class LocTransform (RVTransform ):
684
+ class LocTransform (Transform ):
685
685
name = "loc"
686
686
687
687
def __init__ (self , transform_args_fn ):
@@ -699,7 +699,7 @@ def log_jac_det(self, value, *inputs):
699
699
return pt .zeros_like (value )
700
700
701
701
702
- class ScaleTransform (RVTransform ):
702
+ class ScaleTransform (Transform ):
703
703
name = "scale"
704
704
705
705
def __init__ (self , transform_args_fn ):
@@ -718,7 +718,7 @@ def log_jac_det(self, value, *inputs):
718
718
return - pt .log (pt .abs (pt .broadcast_to (scale , value .shape )))
719
719
720
720
721
- class LogTransform (RVTransform ):
721
+ class LogTransform (Transform ):
722
722
name = "log"
723
723
724
724
def forward (self , value , * inputs ):
@@ -731,7 +731,7 @@ def log_jac_det(self, value, *inputs):
731
731
return value
732
732
733
733
734
- class ExpTransform (RVTransform ):
734
+ class ExpTransform (Transform ):
735
735
name = "exp"
736
736
737
737
def forward (self , value , * inputs ):
@@ -744,7 +744,7 @@ def log_jac_det(self, value, *inputs):
744
744
return - pt .log (value )
745
745
746
746
747
- class AbsTransform (RVTransform ):
747
+ class AbsTransform (Transform ):
748
748
name = "abs"
749
749
750
750
def forward (self , value , * inputs ):
@@ -758,7 +758,7 @@ def log_jac_det(self, value, *inputs):
758
758
return pt .switch (value >= 0 , 0 , np .nan )
759
759
760
760
761
- class PowerTransform (RVTransform ):
761
+ class PowerTransform (Transform ):
762
762
name = "power"
763
763
764
764
def __init__ (self , power = None ):
@@ -801,7 +801,7 @@ def log_jac_det(self, value, *inputs):
801
801
return res
802
802
803
803
804
- class IntervalTransform (RVTransform ):
804
+ class IntervalTransform (Transform ):
805
805
name = "interval"
806
806
807
807
def __init__ (self , args_fn : Callable [..., Tuple [Optional [Variable ], Optional [Variable ]]]):
@@ -909,7 +909,7 @@ def log_jac_det(self, value, *inputs):
909
909
return pt .zeros_like (value )
910
910
911
911
912
- class LogOddsTransform (RVTransform ):
912
+ class LogOddsTransform (Transform ):
913
913
name = "logodds"
914
914
915
915
def backward (self , value , * inputs ):
@@ -923,7 +923,7 @@ def log_jac_det(self, value, *inputs):
923
923
return pt .log (sigmoid_value ) + pt .log1p (- sigmoid_value )
924
924
925
925
926
- class SimplexTransform (RVTransform ):
926
+ class SimplexTransform (Transform ):
927
927
name = "simplex"
928
928
929
929
def forward (self , value , * inputs ):
@@ -950,7 +950,7 @@ def log_jac_det(self, value, *inputs):
950
950
return pt .sum (res , - 1 )
951
951
952
952
953
- class CircularTransform (RVTransform ):
953
+ class CircularTransform (Transform ):
954
954
name = "circular"
955
955
956
956
def backward (self , value , * inputs ):
@@ -963,7 +963,7 @@ def log_jac_det(self, value, *inputs):
963
963
return pt .zeros (value .shape )
964
964
965
965
966
- class ChainedTransform (RVTransform ):
966
+ class ChainedTransform (Transform ):
967
967
name = "chain"
968
968
969
969
def __init__ (self , transform_list , base_op ):
0 commit comments