|
19 | 19 | import urllib.request
|
20 | 20 |
|
21 | 21 | from copy import copy
|
22 |
| -from typing import Any, Dict, List, Optional, Sequence |
| 22 | +from typing import Any, Dict, List, Optional, Sequence, Union |
23 | 23 |
|
24 | 24 | import aesara
|
25 | 25 | import aesara.tensor as at
|
26 | 26 | import numpy as np
|
27 | 27 | import pandas as pd
|
28 | 28 |
|
| 29 | +from aesara.compile.sharedvalue import SharedVariable |
29 | 30 | from aesara.graph.basic import Apply
|
30 | 31 | from aesara.tensor.type import TensorType
|
31 |
| -from aesara.tensor.var import TensorVariable |
| 32 | +from aesara.tensor.var import TensorConstant, TensorVariable |
32 | 33 |
|
33 | 34 | import pymc as pm
|
34 | 35 |
|
|
40 | 41 | "Minibatch",
|
41 | 42 | "align_minibatches",
|
42 | 43 | "Data",
|
| 44 | + "ConstantData", |
| 45 | + "MutableData", |
43 | 46 | ]
|
44 | 47 | BASE_URL = "https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/data/{filename}"
|
45 | 48 |
|
@@ -502,9 +505,64 @@ def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict
|
502 | 505 | return coords
|
503 | 506 |
|
504 | 507 |
|
505 |
| -class Data: |
506 |
| - """Data container class that wraps :func:`aesara.shared` and lets |
507 |
| - the model be aware of its inputs and outputs. |
def ConstantData(
    name: str,
    value,
    *,
    dims: Optional[Sequence[str]] = None,
    export_index_as_coords=False,
    **kwargs,
) -> TensorConstant:
    """Register ``value`` with the model as an immutable ``TensorConstant``.

    Convenience wrapper equivalent to ``pm.Data(..., mutable=False)``.

    Parameters
    ----------
    name : str
        Name of the data variable.
    value
        The data; handed to ``pm.Data`` unchanged.
    dims : sequence of str, optional
        Dimension names; handed to ``pm.Data`` unchanged.
    export_index_as_coords : bool, default=False
        Handed to ``pm.Data`` unchanged.
    **kwargs
        Extra keyword arguments forwarded to ``pm.Data``. ``mutable`` must
        not appear here: this wrapper already supplies it, so passing it
        again raises a ``TypeError`` for the duplicated keyword.

    Returns
    -------
    TensorConstant
        Whatever ``pm.Data`` returns for ``mutable=False``.
    """
    # Only ``mutable`` is pinned; every other argument is forwarded as given.
    constant = Data(
        name,
        value,
        mutable=False,
        dims=dims,
        export_index_as_coords=export_index_as_coords,
        **kwargs,
    )
    return constant
| 528 | + |
| 529 | + |
def MutableData(
    name: str,
    value,
    *,
    dims: Optional[Sequence[str]] = None,
    export_index_as_coords=False,
    **kwargs,
) -> SharedVariable:
    """Register ``value`` with the model as a mutable ``SharedVariable``.

    Convenience wrapper equivalent to ``pm.Data(..., mutable=True)``.

    Parameters
    ----------
    name : str
        Name of the data variable.
    value
        The data; handed to ``pm.Data`` unchanged.
    dims : sequence of str, optional
        Dimension names; handed to ``pm.Data`` unchanged.
    export_index_as_coords : bool, default=False
        Handed to ``pm.Data`` unchanged.
    **kwargs
        Extra keyword arguments forwarded to ``pm.Data``. ``mutable`` must
        not appear here: this wrapper already supplies it, so passing it
        again raises a ``TypeError`` for the duplicated keyword.

    Returns
    -------
    SharedVariable
        Whatever ``pm.Data`` returns for ``mutable=True``.
    """
    # Only ``mutable`` is pinned; every other argument is forwarded as given.
    shared = Data(
        name,
        value,
        mutable=True,
        dims=dims,
        export_index_as_coords=export_index_as_coords,
        **kwargs,
    )
    return shared
| 550 | + |
| 551 | + |
| 552 | +def Data( |
| 553 | + name: str, |
| 554 | + value, |
| 555 | + *, |
| 556 | + dims: Optional[Sequence[str]] = None, |
| 557 | + export_index_as_coords=False, |
| 558 | + mutable: bool = True, |
| 559 | + **kwargs, |
| 560 | +) -> Union[SharedVariable, TensorConstant]: |
| 561 | + """Data container that registers a data variable with the model. |
| 562 | +
|
| 563 | + Depending on the ``mutable`` setting (default: True), the variable |
| 564 | + is registered as a ``SharedVariable``, enabling it to be altered |
| 565 | + in value and shape, but NOT in dimensionality using ``pm.set_data()``. |
508 | 566 |
|
509 | 567 | Parameters
|
510 | 568 | ----------
|
@@ -552,52 +610,46 @@ class Data:
|
552 | 610 | For more information, take a look at this example notebook
|
553 | 611 | https://docs.pymc.io/notebooks/data_container.html
|
554 | 612 | """
|
| 613 | + if isinstance(value, list): |
| 614 | + value = np.array(value) |
555 | 615 |
|
556 |
| - def __new__( |
557 |
| - self, |
558 |
| - name: str, |
559 |
| - value, |
560 |
| - *, |
561 |
| - dims: Optional[Sequence[str]] = None, |
562 |
| - export_index_as_coords=False, |
563 |
| - **kwargs, |
564 |
| - ): |
565 |
| - if isinstance(value, list): |
566 |
| - value = np.array(value) |
567 |
| - |
568 |
| - # Add data container to the named variables of the model. |
569 |
| - try: |
570 |
| - model = pm.Model.get_context() |
571 |
| - except TypeError: |
572 |
| - raise TypeError( |
573 |
| - "No model on context stack, which is needed to instantiate a data container. " |
574 |
| - "Add variable inside a 'with model:' block." |
575 |
| - ) |
576 |
| - name = model.name_for(name) |
577 |
| - |
578 |
| - # `pandas_to_array` takes care of parameter `value` and |
579 |
| - # transforms it to something digestible for pymc |
580 |
| - shared_object = aesara.shared(pandas_to_array(value), name, **kwargs) |
581 |
| - |
582 |
| - if isinstance(dims, str): |
583 |
| - dims = (dims,) |
584 |
| - if not (dims is None or len(dims) == shared_object.ndim): |
585 |
| - raise pm.exceptions.ShapeError( |
586 |
| - "Length of `dims` must match the dimensions of the dataset.", |
587 |
| - actual=len(dims), |
588 |
| - expected=shared_object.ndim, |
589 |
| - ) |
590 |
| - |
591 |
| - coords = determine_coords(model, value, dims) |
592 |
| - |
593 |
| - if export_index_as_coords: |
594 |
| - model.add_coords(coords) |
595 |
| - elif dims: |
596 |
| - # Register new dimension lengths |
597 |
| - for d, dname in enumerate(dims): |
598 |
| - if not dname in model.dim_lengths: |
599 |
| - model.add_coord(dname, values=None, length=shared_object.shape[d]) |
600 |
| - |
601 |
| - model.add_random_variable(shared_object, dims=dims) |
602 |
| - |
603 |
| - return shared_object |
| 616 | + # Add data container to the named variables of the model. |
| 617 | + try: |
| 618 | + model = pm.Model.get_context() |
| 619 | + except TypeError: |
| 620 | + raise TypeError( |
| 621 | + "No model on context stack, which is needed to instantiate a data container. " |
| 622 | + "Add variable inside a 'with model:' block." |
| 623 | + ) |
| 624 | + name = model.name_for(name) |
| 625 | + |
| 626 | + # `pandas_to_array` takes care of parameter `value` and |
| 627 | + # transforms it to something digestible for Aesara. |
| 628 | + arr = pandas_to_array(value) |
| 629 | + if mutable: |
| 630 | + x = aesara.shared(arr, name, **kwargs) |
| 631 | + else: |
| 632 | + x = at.as_tensor_variable(arr, name, **kwargs) |
| 633 | + |
| 634 | + if isinstance(dims, str): |
| 635 | + dims = (dims,) |
| 636 | + if not (dims is None or len(dims) == x.ndim): |
| 637 | + raise pm.exceptions.ShapeError( |
| 638 | + "Length of `dims` must match the dimensions of the dataset.", |
| 639 | + actual=len(dims), |
| 640 | + expected=x.ndim, |
| 641 | + ) |
| 642 | + |
| 643 | + coords = determine_coords(model, value, dims) |
| 644 | + |
| 645 | + if export_index_as_coords: |
| 646 | + model.add_coords(coords) |
| 647 | + elif dims: |
| 648 | + # Register new dimension lengths |
| 649 | + for d, dname in enumerate(dims): |
| 650 | + if not dname in model.dim_lengths: |
| 651 | + model.add_coord(dname, values=None, length=x.shape[d]) |
| 652 | + |
| 653 | + model.add_random_variable(x, dims=dims) |
| 654 | + |
| 655 | + return x |
0 commit comments