@@ -816,41 +816,97 @@ def __init__(self, args_fn: Callable[..., Tuple[Optional[Variable], Optional[Var
816
816
"""
817
817
self .args_fn = args_fn
818
818
819
- def forward (self , value , * inputs ):
819
+ def get_a_and_b (self , inputs ):
820
+ """Return interval bound values.
821
+
822
+ Also returns two boolean variables indicating whether the transform is known to be statically bounded.
823
+ This is used to generate smaller graphs in the transform methods.
824
+ """
820
825
a , b = self .args_fn (* inputs )
826
+ lower_bounded , upper_bounded = True , True
827
+ if a is None :
828
+ a = - pt .inf
829
+ lower_bounded = False
830
+ if b is None :
831
+ b = pt .inf
832
+ upper_bounded = False
833
+ return a , b , lower_bounded , upper_bounded
821
834
822
- if a is not None and b is not None :
823
- return pt .log (value - a ) - pt .log (b - value )
824
- elif a is not None :
825
- return pt .log (value - a )
826
- elif b is not None :
827
- return pt .log (b - value )
835
+ def forward (self , value , * inputs ):
836
+ a , b , lower_bounded , upper_bounded = self .get_a_and_b (inputs )
837
+
838
+ log_lower_distance = pt .log (value - a )
839
+ log_upper_distance = pt .log (b - value )
840
+
841
+ if lower_bounded and upper_bounded :
842
+ return pt .where (
843
+ pt .and_ (pt .neq (a , - pt .inf ), pt .neq (b , pt .inf )),
844
+ log_lower_distance - log_upper_distance ,
845
+ pt .where (
846
+ pt .neq (a , - pt .inf ),
847
+ log_lower_distance ,
848
+ pt .where (
849
+ pt .neq (b , pt .inf ),
850
+ log_upper_distance ,
851
+ value ,
852
+ ),
853
+ ),
854
+ )
855
+ elif lower_bounded :
856
+ return log_lower_distance
857
+ elif upper_bounded :
858
+ return log_upper_distance
828
859
else :
829
- raise ValueError ( "Both edges of IntervalTransform cannot be None" )
860
+ return value
830
861
831
862
def backward (self , value , * inputs ):
832
- a , b = self .args_fn (* inputs )
833
-
834
- if a is not None and b is not None :
835
- sigmoid_x = pt .sigmoid (value )
836
- return sigmoid_x * b + (1 - sigmoid_x ) * a
837
- elif a is not None :
838
- return pt .exp (value ) + a
839
- elif b is not None :
840
- return b - pt .exp (value )
863
+ a , b , lower_bounded , upper_bounded = self .get_a_and_b (inputs )
864
+
865
+ exp_value = pt .exp (value )
866
+ sigmoid_x = pt .sigmoid (value )
867
+ lower_distance = exp_value + a
868
+ upper_distance = b - exp_value
869
+
870
+ if lower_bounded and upper_bounded :
871
+ return pt .where (
872
+ pt .and_ (pt .neq (a , - pt .inf ), pt .neq (b , pt .inf )),
873
+ sigmoid_x * b + (1 - sigmoid_x ) * a ,
874
+ pt .where (
875
+ pt .neq (a , - pt .inf ),
876
+ lower_distance ,
877
+ pt .where (
878
+ pt .neq (b , pt .inf ),
879
+ upper_distance ,
880
+ value ,
881
+ ),
882
+ ),
883
+ )
884
+ elif lower_bounded :
885
+ return lower_distance
886
+ elif upper_bounded :
887
+ return upper_distance
841
888
else :
842
- raise ValueError ( "Both edges of IntervalTransform cannot be None" )
889
+ return value
843
890
844
891
def log_jac_det (self , value , * inputs ):
845
- a , b = self .args_fn ( * inputs )
892
+ a , b , lower_bounded , upper_bounded = self .get_a_and_b ( inputs )
846
893
847
- if a is not None and b is not None :
894
+ if lower_bounded and upper_bounded :
848
895
s = pt .softplus (- value )
849
- return pt .log (b - a ) - 2 * s - value
850
- elif a is None and b is None :
851
- raise ValueError ("Both edges of IntervalTransform cannot be None" )
852
- else :
896
+
897
+ return pt .where (
898
+ pt .and_ (pt .neq (a , - pt .inf ), pt .neq (b , pt .inf )),
899
+ pt .log (b - a ) - 2 * s - value ,
900
+ pt .where (
901
+ pt .or_ (pt .neq (a , - pt .inf ), pt .neq (b , pt .inf )),
902
+ value ,
903
+ pt .zeros_like (value ),
904
+ ),
905
+ )
906
+ elif lower_bounded or upper_bounded :
853
907
return value
908
+ else :
909
+ return pt .zeros_like (value )
854
910
855
911
856
912
class LogOddsTransform (RVTransform ):
0 commit comments