19
19
import urllib .request
20
20
21
21
from copy import copy
22
- from typing import Any , Dict , List , Sequence
22
+ from typing import Any , Dict , List , Optional , Sequence
23
23
24
24
import aesara
25
25
import aesara .tensor as at
@@ -463,6 +463,45 @@ def align_minibatches(batches=None):
463
463
rng .seed ()
464
464
465
465
466
def _resolve_axis_dim_name(dims, axis, labels):
    """Pick the dim name for one pandas axis: ``dims[axis]`` if set, else ``labels.name``."""
    dim_name = dims[axis] if dims is not None else None
    if dim_name is None:
        # No explicit name was given for this axis: fall back to the name
        # carried by the pandas index/columns object (which may be None too).
        dim_name = labels.name
    return dim_name


def determine_coords(
    model, value, dims: Optional[Sequence[Optional[str]]] = None
) -> Dict[str, Sequence]:
    """Determines coordinate values from data or the model (via ``dims``).

    Parameters
    ----------
    model
        Object exposing a ``coords`` mapping; only consulted when ``value``
        is a NumPy array and ``dims`` is given.
    value
        The data. ``pd.Series``/``pd.DataFrame`` contribute their index
        (and, for a DataFrame, their columns) as coordinate values.
        ``np.ndarray`` contributes a ``pd.RangeIndex`` per named dimension
        that ``model.coords`` does not already hold values for.
        Any other type yields no coordinates.
    dims : sequence of str or None, optional
        Dimension names, one per axis of ``value``. A ``None`` entry means
        "unnamed": for pandas objects the index/columns name is used instead
        (if it has one); for arrays that axis gets no coordinate.

    Returns
    -------
    dict
        Mapping from dimension name to its coordinate values.

    Raises
    ------
    pm.exceptions.ShapeError
        If ``value`` is a NumPy array whose rank differs from ``len(dims)``.
    """
    coords = {}

    # If value is a df or a series, we interpret the index as coords:
    if isinstance(value, (pd.Series, pd.DataFrame)):
        dim_name = _resolve_axis_dim_name(dims, 0, value.index)
        if dim_name is not None:
            coords[dim_name] = value.index

    # If value is a df, we also interpret the columns as coords:
    if isinstance(value, pd.DataFrame):
        dim_name = _resolve_axis_dim_name(dims, 1, value.columns)
        if dim_name is not None:
            coords[dim_name] = value.columns

    if isinstance(value, np.ndarray) and dims is not None:
        if len(dims) != value.ndim:
            # Report rank vs. number of dims: those are the two quantities
            # that are actually being compared here.
            raise pm.exceptions.ShapeError(
                "Invalid data shape. The rank of the dataset must match the length of `dims`.",
                actual=value.ndim,
                expected=len(dims),
            )
        for size, dim in zip(value.shape, dims):
            if dim is None:
                # Unnamed axis: nothing to register (avoids a ``None`` key).
                continue
            # Only invent default coordinate values when the model has none.
            if model.coords.get(dim, None) is None:
                coords[dim] = pd.RangeIndex(size, name=dim)

    return coords
503
+
504
+
466
505
class Data :
467
506
"""Data container class that wraps :func:`aesara.shared` and lets
468
507
the model be aware of its inputs and outputs.
@@ -516,10 +555,10 @@ class Data:
516
555
517
556
def __new__ (
518
557
self ,
519
- name ,
558
+ name : str ,
520
559
value ,
521
560
* ,
522
- dims = None ,
561
+ dims : Optional [ Sequence [ str ]] = None ,
523
562
export_index_as_coords = False ,
524
563
** kwargs ,
525
564
):
@@ -549,7 +588,7 @@ def __new__(
549
588
expected = shared_object .ndim ,
550
589
)
551
590
552
- coords = self . set_coords (model , value , dims )
591
+ coords = determine_coords (model , value , dims )
553
592
554
593
if export_index_as_coords :
555
594
model .add_coords (coords )
@@ -559,58 +598,6 @@ def __new__(
559
598
if not dname in model .dim_lengths :
560
599
model .add_coord (dname , values = None , length = shared_object .shape [d ])
561
600
562
- # To draw the node for this variable in the graphviz Digraph we need
563
- # its shape.
564
- # XXX: This needs to be refactored
565
- # shared_object.dshape = tuple(shared_object.shape.eval())
566
- # if dims is not None:
567
- # shape_dims = model.shape_from_dims(dims)
568
- # if shared_object.dshape != shape_dims:
569
- # raise pm.exceptions.ShapeError(
570
- # "Data shape does not match with specified `dims`.",
571
- # actual=shared_object.dshape,
572
- # expected=shape_dims,
573
- # )
574
-
575
601
model .add_random_variable (shared_object , dims = dims )
576
602
577
603
return shared_object
578
-
579
- @staticmethod
580
- def set_coords (model , value , dims = None ) -> Dict [str , Sequence ]:
581
- coords = {}
582
-
583
- # If value is a df or a series, we interpret the index as coords:
584
- if isinstance (value , (pd .Series , pd .DataFrame )):
585
- dim_name = None
586
- if dims is not None :
587
- dim_name = dims [0 ]
588
- if dim_name is None and value .index .name is not None :
589
- dim_name = value .index .name
590
- if dim_name is not None :
591
- coords [dim_name ] = value .index
592
-
593
- # If value is a df, we also interpret the columns as coords:
594
- if isinstance (value , pd .DataFrame ):
595
- dim_name = None
596
- if dims is not None :
597
- dim_name = dims [1 ]
598
- if dim_name is None and value .columns .name is not None :
599
- dim_name = value .columns .name
600
- if dim_name is not None :
601
- coords [dim_name ] = value .columns
602
-
603
- if isinstance (value , np .ndarray ) and dims is not None :
604
- if len (dims ) != value .ndim :
605
- raise pm .exceptions .ShapeError (
606
- "Invalid data shape. The rank of the dataset must match the "
607
- "length of `dims`." ,
608
- actual = value .shape ,
609
- expected = value .ndim ,
610
- )
611
- for size , dim in zip (value .shape , dims ):
612
- coord = model .coords .get (dim , None )
613
- if coord is None :
614
- coords [dim ] = pd .RangeIndex (size , name = dim )
615
-
616
- return coords
0 commit comments