Skip to content

Commit a9b8afe

Browse files
michaelosthege authored and twiecki committed
Extract staticmethod into a function
1 parent b6f76e5 commit a9b8afe

File tree

1 file changed

+43
-56
lines changed

1 file changed

+43
-56
lines changed

pymc/data.py

+43-56
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import urllib.request
2020

2121
from copy import copy
22-
from typing import Any, Dict, List, Sequence
22+
from typing import Any, Dict, List, Optional, Sequence
2323

2424
import aesara
2525
import aesara.tensor as at
@@ -463,6 +463,45 @@ def align_minibatches(batches=None):
463463
rng.seed()
464464

465465

466+
def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict[str, Sequence]:
    """Determines coordinate values from data or the model (via ``dims``).

    Parameters
    ----------
    model
        Object exposing a ``coords`` mapping. It is only consulted when
        ``value`` is a NumPy array and ``dims`` is given.
    value
        The data. For a pandas Series or DataFrame the index is used as
        coordinate values; for a DataFrame the columns are used as well.
    dims : sequence of str, optional
        Dimension names, one per axis of ``value``. For pandas objects an
        entry takes precedence over the name of the index / columns.

    Returns
    -------
    dict
        Mapping from dimension name to coordinate values. Axes for which no
        name can be determined are left out.

    Raises
    ------
    pm.exceptions.ShapeError
        If ``value`` is a NumPy array whose rank differs from ``len(dims)``.
    """
    coords = {}

    # If value is a df or a series, we interpret the index as coords:
    if isinstance(value, (pd.Series, pd.DataFrame)):
        dim_name = dims[0] if dims is not None else None
        if dim_name is None:
            # No explicit name was given: fall back to the index's own name
            # (which may itself be None, in which case nothing is recorded).
            dim_name = value.index.name
        if dim_name is not None:
            coords[dim_name] = value.index

    # If value is a df, we also interpret the columns as coords:
    if isinstance(value, pd.DataFrame):
        dim_name = dims[1] if dims is not None else None
        if dim_name is None:
            dim_name = value.columns.name
        if dim_name is not None:
            coords[dim_name] = value.columns

    if isinstance(value, np.ndarray) and dims is not None:
        if len(dims) != value.ndim:
            # Report the two quantities that were actually compared: the rank
            # of the data versus the number of dimension names supplied.
            raise pm.exceptions.ShapeError(
                "Invalid data shape. The rank of the dataset must match the length of `dims`.",
                actual=value.ndim,
                expected=len(dims),
            )
        for size, dim in zip(value.shape, dims):
            # Only create default integer coordinates for dimensions that the
            # model does not already have coordinate values for.
            if model.coords.get(dim, None) is None:
                coords[dim] = pd.RangeIndex(size, name=dim)

    return coords
503+
504+
466505
class Data:
467506
"""Data container class that wraps :func:`aesara.shared` and lets
468507
the model be aware of its inputs and outputs.
@@ -516,10 +555,10 @@ class Data:
516555

517556
def __new__(
518557
self,
519-
name,
558+
name: str,
520559
value,
521560
*,
522-
dims=None,
561+
dims: Optional[Sequence[str]] = None,
523562
export_index_as_coords=False,
524563
**kwargs,
525564
):
@@ -549,7 +588,7 @@ def __new__(
549588
expected=shared_object.ndim,
550589
)
551590

552-
coords = self.set_coords(model, value, dims)
591+
coords = determine_coords(model, value, dims)
553592

554593
if export_index_as_coords:
555594
model.add_coords(coords)
@@ -559,58 +598,6 @@ def __new__(
559598
if not dname in model.dim_lengths:
560599
model.add_coord(dname, values=None, length=shared_object.shape[d])
561600

562-
# To draw the node for this variable in the graphviz Digraph we need
563-
# its shape.
564-
# XXX: This needs to be refactored
565-
# shared_object.dshape = tuple(shared_object.shape.eval())
566-
# if dims is not None:
567-
# shape_dims = model.shape_from_dims(dims)
568-
# if shared_object.dshape != shape_dims:
569-
# raise pm.exceptions.ShapeError(
570-
# "Data shape does not match with specified `dims`.",
571-
# actual=shared_object.dshape,
572-
# expected=shape_dims,
573-
# )
574-
575601
model.add_random_variable(shared_object, dims=dims)
576602

577603
return shared_object
578-
579-
@staticmethod
580-
def set_coords(model, value, dims=None) -> Dict[str, Sequence]:
581-
coords = {}
582-
583-
# If value is a df or a series, we interpret the index as coords:
584-
if isinstance(value, (pd.Series, pd.DataFrame)):
585-
dim_name = None
586-
if dims is not None:
587-
dim_name = dims[0]
588-
if dim_name is None and value.index.name is not None:
589-
dim_name = value.index.name
590-
if dim_name is not None:
591-
coords[dim_name] = value.index
592-
593-
# If value is a df, we also interpret the columns as coords:
594-
if isinstance(value, pd.DataFrame):
595-
dim_name = None
596-
if dims is not None:
597-
dim_name = dims[1]
598-
if dim_name is None and value.columns.name is not None:
599-
dim_name = value.columns.name
600-
if dim_name is not None:
601-
coords[dim_name] = value.columns
602-
603-
if isinstance(value, np.ndarray) and dims is not None:
604-
if len(dims) != value.ndim:
605-
raise pm.exceptions.ShapeError(
606-
"Invalid data shape. The rank of the dataset must match the "
607-
"length of `dims`.",
608-
actual=value.shape,
609-
expected=value.ndim,
610-
)
611-
for size, dim in zip(value.shape, dims):
612-
coord = model.coords.get(dim, None)
613-
if coord is None:
614-
coords[dim] = pd.RangeIndex(size, name=dim)
615-
616-
return coords

0 commit comments

Comments
 (0)