|
37 | 37 | from pymc.distributions.shape_utils import change_dist_size
|
38 | 38 | from pymc.initial_point import make_initial_point_fn
|
39 | 39 | from pymc.logprob import joint_logp
|
| 40 | +from pymc.logprob.abstract import icdf |
40 | 41 | from pymc.logprob.utils import ParameterValueError
|
41 | 42 | from pymc.pytensorf import (
|
42 | 43 | compile_pymc,
|
@@ -520,6 +521,97 @@ def check_logcdf(
|
520 | 521 | )
|
521 | 522 |
|
522 | 523 |
|
| 524 | +def check_icdf( |
| 525 | + pymc_dist: Distribution, |
| 526 | + paramdomains: Dict[str, Domain], |
| 527 | + scipy_icdf: Callable, |
| 528 | + decimal: Optional[int] = None, |
| 529 | + n_samples: int = 100, |
| 530 | +) -> None: |
| 531 | + """ |
| 532 | + Generic test for PyMC icdf methods |
| 533 | +
|
| 534 | + The following tests are performed by default: |
| 535 | + 1. Test PyMC icdf and equivalent scipy icdf (ppf) methods give similar |
| 536 | + results for parameters inside the supported edges. |
| 537 | + Edges are excluded by default, but can be artificially included by |
| 538 | + creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`) |
| 539 | + 2. Test PyMC icdf method raises for invalid parameter values |
| 540 | + outside the supported edges. |
| 541 | + 3. Test PyMC icdf method returns np.nan for values below 0 or above 1, |
| 542 | + when using valid parameters. |
| 543 | +
|
| 544 | + Parameters |
| 545 | + ---------- |
| 546 | + pymc_dist: PyMC distribution |
| 547 | + paramdomains : Dictionary of Parameter : Domain pairs |
| 548 | + Supported domains of distribution parameters |
| 549 | + scipy_icdf : Scipy icdf method |
| 550 | + Scipy icdf (ppp) method of equivalent pymc_dist distribution |
| 551 | + decimal : int, optional |
| 552 | + Level of precision with which pymc_dist and scipy_icdf are compared. |
| 553 | + Defaults to 6 for float64 and 3 for float32 |
| 554 | + n_samples : int |
| 555 | + Upper limit on the number of valid domain and value combinations that |
| 556 | + are compared between pymc and scipy methods. If n_samples is below the |
| 557 | + total number of combinations, a random subset is evaluated. Setting |
| 558 | + n_samples = -1, will return all possible combinations. Defaults to 100 |
| 559 | +
|
| 560 | + """ |
| 561 | + if decimal is None: |
| 562 | + decimal = select_by_precision(float64=6, float32=3) |
| 563 | + |
| 564 | + dist = create_dist_from_paramdomains(pymc_dist, paramdomains) |
| 565 | + q = pt.scalar(dtype="float64", name="q") |
| 566 | + dist_icdf = icdf(dist, q) |
| 567 | + pymc_icdf = pytensor.function(list(inputvars(dist_icdf)), dist_icdf) |
| 568 | + |
| 569 | + # Test pymc and scipy distributions match for values and parameters |
| 570 | + # within the supported domain edges (excluding edges) |
| 571 | + domains = paramdomains.copy() |
| 572 | + domain = Domain([0, 0.1, 0.5, 0.75, 0.95, 0.99, 1]) # Values we test the icdf at |
| 573 | + domains["q"] = domain |
| 574 | + |
| 575 | + for point in product(domains, n_samples=n_samples): |
| 576 | + point = dict(point) |
| 577 | + npt.assert_almost_equal( |
| 578 | + pymc_icdf(**point), |
| 579 | + scipy_icdf(**point), |
| 580 | + decimal=decimal, |
| 581 | + err_msg=str(point), |
| 582 | + ) |
| 583 | + |
| 584 | + valid_value = domain.vals[0] |
| 585 | + valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()} |
| 586 | + valid_params["q"] = valid_value |
| 587 | + |
| 588 | + # Test pymc distribution raises ParameterValueError for parameters outside the |
| 589 | + # supported domain edges (excluding edges) |
| 590 | + invalid_params = find_invalid_scalar_params(paramdomains) |
| 591 | + for invalid_param, invalid_edges in invalid_params.items(): |
| 592 | + for invalid_edge in invalid_edges: |
| 593 | + if invalid_edge is None: |
| 594 | + continue |
| 595 | + |
| 596 | + point = valid_params.copy() |
| 597 | + point[invalid_param] = invalid_edge |
| 598 | + with pytest.raises(ParameterValueError): |
| 599 | + pymc_icdf(**point) |
| 600 | + pytest.fail(f"test_params={point}") |
| 601 | + |
| 602 | + # Test that values below 0 or above 1 evaluate to nan |
| 603 | + invalid_values = find_invalid_scalar_params({"q": domain})["q"] |
| 604 | + for invalid_value in invalid_values: |
| 605 | + if invalid_value is not None: |
| 606 | + point = valid_params.copy() |
| 607 | + point["q"] = invalid_value |
| 608 | + npt.assert_equal( |
| 609 | + pymc_icdf(**point), |
| 610 | + np.nan, |
| 611 | + err_msg=str(point), |
| 612 | + ) |
| 613 | + |
| 614 | + |
523 | 615 | def check_selfconsistency_discrete_logcdf(
|
524 | 616 | distribution: Distribution,
|
525 | 617 | domain: Domain,
|
|
0 commit comments