Skip to content

Commit 04a03b5

Browse files
committed
Rename RVTransform to Transform
1 parent 1d7c957 commit 04a03b5

File tree

11 files changed

+58
-54
lines changed

11 files changed

+58
-54
lines changed

pymc/distributions/transforms.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131
IntervalTransform,
3232
LogOddsTransform,
3333
LogTransform,
34-
RVTransform,
3534
SimplexTransform,
35+
Transform,
3636
)
3737

3838
__all__ = [
39-
"RVTransform",
39+
"Transform",
4040
"simplex",
4141
"logodds",
4242
"Interval",
@@ -60,6 +60,10 @@ def __getattr__(name):
6060
warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning)
6161
return sum_to_1
6262

63+
if name == "RVTransform":
64+
warnings.warn("RVTransform has been renamed to Transform", FutureWarning)
65+
return Transform
66+
6367
raise AttributeError(f"module {__name__} has no attribute {name}")
6468

6569

@@ -69,7 +73,7 @@ def _default_transform(op: Op, rv: TensorVariable):
6973
return None
7074

7175

72-
class LogExpM1(RVTransform):
76+
class LogExpM1(Transform):
7377
name = "log_exp_m1"
7478

7579
def backward(self, value, *inputs):
@@ -87,7 +91,7 @@ def log_jac_det(self, value, *inputs):
8791
return -pt.softplus(-value)
8892

8993

90-
class Ordered(RVTransform):
94+
class Ordered(Transform):
9195
name = "ordered"
9296

9397
def __init__(self, ndim_supp=None):
@@ -110,7 +114,7 @@ def log_jac_det(self, value, *inputs):
110114
return pt.sum(value[..., 1:], axis=-1)
111115

112116

113-
class SumTo1(RVTransform):
117+
class SumTo1(Transform):
114118
"""
115119
Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1]
116120
This Transformation operates on the last dimension of the input tensor.
@@ -134,7 +138,7 @@ def log_jac_det(self, value, *inputs):
134138
return pt.sum(y, axis=-1)
135139

136140

137-
class CholeskyCovPacked(RVTransform):
141+
class CholeskyCovPacked(Transform):
138142
"""
139143
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
140144
log scale
@@ -162,7 +166,7 @@ def log_jac_det(self, value, *inputs):
162166
return pt.sum(value[..., self.diag_idxs], axis=-1)
163167

164168

165-
class Chain(RVTransform):
169+
class Chain(Transform):
166170
__slots__ = ("param_extract_fn", "transform_list", "name")
167171

168172
def __init__(self, transform_list):
@@ -297,7 +301,7 @@ def bounds_fn(*rv_inputs):
297301
super().__init__(args_fn=bounds_fn)
298302

299303

300-
class ZeroSumTransform(RVTransform):
304+
class ZeroSumTransform(Transform):
301305
"""
302306
Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``.
303307

pymc/initial_point.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pytensor.graph.fg import FunctionGraph
2525
from pytensor.tensor.variable import TensorVariable
2626

27-
from pymc.logprob.transforms import RVTransform
27+
from pymc.logprob.transforms import Transform
2828
from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs
2929
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name
3030

@@ -177,7 +177,7 @@ def inner(seed, *args, **kwargs):
177177
def make_initial_point_expression(
178178
*,
179179
free_rvs: Sequence[TensorVariable],
180-
rvs_to_transforms: Dict[TensorVariable, RVTransform],
180+
rvs_to_transforms: Dict[TensorVariable, Transform],
181181
initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
182182
jitter_rvs: Set[TensorVariable] = None,
183183
default_strategy: str = "moment",

pymc/logprob/basic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
)
6565
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
6666
from pymc.logprob.transform_value import TransformValuesRewrite
67-
from pymc.logprob.transforms import RVTransform
67+
from pymc.logprob.transforms import Transform
6868
from pymc.logprob.utils import find_rvs_in_graph, rvs_to_value_vars
6969

7070
TensorLike: TypeAlias = Union[Variable, float, np.ndarray]
@@ -589,7 +589,7 @@ def transformed_conditional_logp(
589589
rvs: Sequence[TensorVariable],
590590
*,
591591
rvs_to_values: Dict[TensorVariable, TensorVariable],
592-
rvs_to_transforms: Dict[TensorVariable, RVTransform],
592+
rvs_to_transforms: Dict[TensorVariable, Transform],
593593
jacobian: bool = True,
594594
**kwargs,
595595
) -> List[TensorVariable]:

pymc/logprob/transform_value.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from pymc.logprob.abstract import MeasurableVariable, _logprob
2929
from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db
30-
from pymc.logprob.transforms import RVTransform
30+
from pymc.logprob.transforms import Transform
3131

3232

3333
class TransformedValue(Op):
@@ -67,7 +67,7 @@ class TransformedValueRV(Op):
6767

6868
__props__ = ("transforms",)
6969

70-
def __init__(self, transforms: Sequence[RVTransform]):
70+
def __init__(self, transforms: Sequence[Transform]):
7171
self.transforms = tuple(transforms)
7272
super().__init__()
7373

@@ -320,7 +320,7 @@ class TransformValuesRewrite(GraphRewriter):
320320

321321
def __init__(
322322
self,
323-
values_to_transforms: Dict[TensorVariable, Union[RVTransform, None]],
323+
values_to_transforms: Dict[TensorVariable, Union[Transform, None]],
324324
):
325325
"""
326326
Parameters

pymc/logprob/transforms.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@
120120
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
121121

122122

123-
class RVTransform(abc.ABC):
123+
class Transform(abc.ABC):
124124
ndim_supp = None
125125

126126
@abc.abstractmethod
@@ -174,10 +174,10 @@ class MeasurableTransform(MeasurableElemwise):
174174

175175
# Cannot use `transform` as name because it would clash with the property added by
176176
# the `TransformValuesRewrite`
177-
transform_elemwise: RVTransform
177+
transform_elemwise: Transform
178178
measurable_input_idx: int
179179

180-
def __init__(self, *args, transform: RVTransform, measurable_input_idx: int, **kwargs):
180+
def __init__(self, *args, transform: Transform, measurable_input_idx: int, **kwargs):
181181
self.transform_elemwise = transform
182182
self.measurable_input_idx = measurable_input_idx
183183
super().__init__(*args, **kwargs)
@@ -444,7 +444,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
444444
scalar_op = node.op.scalar_op
445445
measurable_input_idx = 0
446446
transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,)
447-
transform: RVTransform
447+
transform: Transform
448448

449449
transform_dict = {
450450
Exp: ExpTransform(),
@@ -559,7 +559,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
559559
)
560560

561561

562-
class SinhTransform(RVTransform):
562+
class SinhTransform(Transform):
563563
name = "sinh"
564564
ndim_supp = 0
565565

@@ -570,7 +570,7 @@ def backward(self, value, *inputs):
570570
return pt.arcsinh(value)
571571

572572

573-
class CoshTransform(RVTransform):
573+
class CoshTransform(Transform):
574574
name = "cosh"
575575
ndim_supp = 0
576576

@@ -589,7 +589,7 @@ def log_jac_det(self, value, *inputs):
589589
)
590590

591591

592-
class TanhTransform(RVTransform):
592+
class TanhTransform(Transform):
593593
name = "tanh"
594594
ndim_supp = 0
595595

@@ -600,7 +600,7 @@ def backward(self, value, *inputs):
600600
return pt.arctanh(value)
601601

602602

603-
class ArcsinhTransform(RVTransform):
603+
class ArcsinhTransform(Transform):
604604
name = "arcsinh"
605605
ndim_supp = 0
606606

@@ -611,7 +611,7 @@ def backward(self, value, *inputs):
611611
return pt.sinh(value)
612612

613613

614-
class ArccoshTransform(RVTransform):
614+
class ArccoshTransform(Transform):
615615
name = "arccosh"
616616
ndim_supp = 0
617617

@@ -622,7 +622,7 @@ def backward(self, value, *inputs):
622622
return pt.cosh(value)
623623

624624

625-
class ArctanhTransform(RVTransform):
625+
class ArctanhTransform(Transform):
626626
name = "arctanh"
627627
ndim_supp = 0
628628

@@ -633,7 +633,7 @@ def backward(self, value, *inputs):
633633
return pt.tanh(value)
634634

635635

636-
class ErfTransform(RVTransform):
636+
class ErfTransform(Transform):
637637
name = "erf"
638638
ndim_supp = 0
639639

@@ -644,7 +644,7 @@ def backward(self, value, *inputs):
644644
return pt.erfinv(value)
645645

646646

647-
class ErfcTransform(RVTransform):
647+
class ErfcTransform(Transform):
648648
name = "erfc"
649649
ndim_supp = 0
650650

@@ -655,7 +655,7 @@ def backward(self, value, *inputs):
655655
return pt.erfcinv(value)
656656

657657

658-
class ErfcxTransform(RVTransform):
658+
class ErfcxTransform(Transform):
659659
name = "erfcx"
660660
ndim_supp = 0
661661

@@ -681,7 +681,7 @@ def calc_delta_x(value, prior_result):
681681
return result[-1]
682682

683683

684-
class LocTransform(RVTransform):
684+
class LocTransform(Transform):
685685
name = "loc"
686686

687687
def __init__(self, transform_args_fn):
@@ -699,7 +699,7 @@ def log_jac_det(self, value, *inputs):
699699
return pt.zeros_like(value)
700700

701701

702-
class ScaleTransform(RVTransform):
702+
class ScaleTransform(Transform):
703703
name = "scale"
704704

705705
def __init__(self, transform_args_fn):
@@ -718,7 +718,7 @@ def log_jac_det(self, value, *inputs):
718718
return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape)))
719719

720720

721-
class LogTransform(RVTransform):
721+
class LogTransform(Transform):
722722
name = "log"
723723

724724
def forward(self, value, *inputs):
@@ -731,7 +731,7 @@ def log_jac_det(self, value, *inputs):
731731
return value
732732

733733

734-
class ExpTransform(RVTransform):
734+
class ExpTransform(Transform):
735735
name = "exp"
736736

737737
def forward(self, value, *inputs):
@@ -744,7 +744,7 @@ def log_jac_det(self, value, *inputs):
744744
return -pt.log(value)
745745

746746

747-
class AbsTransform(RVTransform):
747+
class AbsTransform(Transform):
748748
name = "abs"
749749

750750
def forward(self, value, *inputs):
@@ -758,7 +758,7 @@ def log_jac_det(self, value, *inputs):
758758
return pt.switch(value >= 0, 0, np.nan)
759759

760760

761-
class PowerTransform(RVTransform):
761+
class PowerTransform(Transform):
762762
name = "power"
763763

764764
def __init__(self, power=None):
@@ -801,7 +801,7 @@ def log_jac_det(self, value, *inputs):
801801
return res
802802

803803

804-
class IntervalTransform(RVTransform):
804+
class IntervalTransform(Transform):
805805
name = "interval"
806806

807807
def __init__(self, args_fn: Callable[..., Tuple[Optional[Variable], Optional[Variable]]]):
@@ -909,7 +909,7 @@ def log_jac_det(self, value, *inputs):
909909
return pt.zeros_like(value)
910910

911911

912-
class LogOddsTransform(RVTransform):
912+
class LogOddsTransform(Transform):
913913
name = "logodds"
914914

915915
def backward(self, value, *inputs):
@@ -923,7 +923,7 @@ def log_jac_det(self, value, *inputs):
923923
return pt.log(sigmoid_value) + pt.log1p(-sigmoid_value)
924924

925925

926-
class SimplexTransform(RVTransform):
926+
class SimplexTransform(Transform):
927927
name = "simplex"
928928

929929
def forward(self, value, *inputs):
@@ -950,7 +950,7 @@ def log_jac_det(self, value, *inputs):
950950
return pt.sum(res, -1)
951951

952952

953-
class CircularTransform(RVTransform):
953+
class CircularTransform(Transform):
954954
name = "circular"
955955

956956
def backward(self, value, *inputs):
@@ -963,7 +963,7 @@ def log_jac_det(self, value, *inputs):
963963
return pt.zeros(value.shape)
964964

965965

966-
class ChainedTransform(RVTransform):
966+
class ChainedTransform(Transform):
967967
name = "chain"
968968

969969
def __init__(self, transform_list, base_op):

pymc/model/fgraph.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pytensor.tensor.elemwise import Elemwise
2525
from pytensor.tensor.sharedvar import ScalarSharedVariable
2626

27-
from pymc.logprob.transforms import RVTransform
27+
from pymc.logprob.transforms import Transform
2828
from pymc.model.core import Model
2929
from pymc.pytensorf import StringType, find_rng_nodes, toposort_replace
3030

@@ -59,8 +59,8 @@ def perform(self, *args, **kwargs):
5959
class ModelValuedVar(ModelVar):
6060
__props__ = ("transform",)
6161

62-
def __init__(self, transform: Optional[RVTransform] = None):
63-
if transform is not None and not isinstance(transform, RVTransform):
62+
def __init__(self, transform: Optional[Transform] = None):
63+
if transform is not None and not isinstance(transform, Transform):
6464
raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}")
6565
self.transform = transform
6666
super().__init__()

pymc/model/transform/conditioning.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pytensor.tensor.random.op import RandomVariable
2424

2525
from pymc import Model
26-
from pymc.logprob.transforms import RVTransform
26+
from pymc.logprob.transforms import Transform
2727
from pymc.model.fgraph import (
2828
ModelDeterministic,
2929
ModelFreeRV,
@@ -263,7 +263,7 @@ def do(
263263

264264
def change_value_transforms(
265265
model: Model,
266-
vars_to_transforms: Mapping[ModelVariable, Union[RVTransform, None]],
266+
vars_to_transforms: Mapping[ModelVariable, Union[Transform, None]],
267267
) -> Model:
268268
"""Change the value variables transforms in the model
269269

0 commit comments

Comments
 (0)