21
21
from monty .json import MontyEncoder
22
22
from abipy .tools .serialization import pmg_serialize
23
23
from abipy .tools .iotools import make_executable
24
+ from abipy .core .structure import Structure
24
25
from abipy .core .mixins import NotebookWriter
25
26
from abipy .tools .numtools import sort_and_groupby
26
27
from abipy .tools import duck
@@ -333,6 +334,30 @@ def from_flow(cls, flow, outdirs="all", nids=None, ext=None, task_class=None) ->
333
334
334
335
return robot
335
336
337
+
338
+ def __len__ (self ):
339
+ return len (self ._abifiles )
340
+
341
+ #def __iter__(self):
342
+ # return iter(self._abifiles)
343
+
344
+ def __getitem__ (self , key ):
345
+ # self[key]
346
+ return self ._abifiles .__getitem__ (key )
347
+
348
+ def __enter__ (self ):
349
+ return self
350
+
351
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
352
+ """Activated at the end of the with statement."""
353
+ self .close ()
354
+
355
+ def keys (self ):
356
+ return self ._abifiles .keys ()
357
+
358
+ def items (self ):
359
+ return self ._abifiles .items ()
360
+
336
361
def add_extfile_of_node (self , node , nids = None , task_class = None ) -> None :
337
362
"""
338
363
Add the file produced by this node to the robot.
@@ -547,32 +572,6 @@ def exceptions(self) -> list:
547
572
"""List of exceptions."""
548
573
return self ._exceptions
549
574
550
- def __len__ (self ):
551
- return len (self ._abifiles )
552
-
553
- #def __iter__(self):
554
- # return iter(self._abifiles)
555
-
556
- #def __contains__(self, item):
557
- # return item in self._abifiles
558
-
559
- def __getitem__ (self , key ):
560
- # self[key]
561
- return self ._abifiles .__getitem__ (key )
562
-
563
- def __enter__ (self ):
564
- return self
565
-
566
- def __exit__ (self , exc_type , exc_val , exc_tb ):
567
- """Activated at the end of the with statement."""
568
- self .close ()
569
-
570
- def keys (self ):
571
- return self ._abifiles .keys ()
572
-
573
- def items (self ):
574
- return self ._abifiles .items ()
575
-
576
575
@property
577
576
def labels (self ) -> list [str ]:
578
577
"""
@@ -614,6 +613,24 @@ def _repr_html_(self) -> str:
614
613
"""Integration with jupyter_ notebooks."""
615
614
return '<ol start="0">\n {}\n </ol>' .format ("\n " .join ("<li>%s</li>" % label for label , abifile in self .items ()))
616
615
616
+ def getattr_alleq (self , aname : str ):
617
+ """
618
+ Return the value of attribute aname.
619
+ Raises ValueError if value is not the same across all the files in the robot.
620
+ """
621
+ val1 = getattr (self .abifiles [0 ], aname )
622
+
623
+ for abifile in self .abifiles [1 :]:
624
+ val2 = getattr (abifile , aname )
625
+ if isinstance (val1 , (str , int , float )):
626
+ eq = val1 == val2
627
+ elif isinstance (val1 , np .ndarray ):
628
+ eq = np .allclose (val1 , val2 )
629
+ if not eq :
630
+ raise ValueError (f"Different values of { aname = } , { val1 = } , { val2 = } " )
631
+
632
+ return val1
633
+
617
634
@property
618
635
def abifiles (self ) -> list :
619
636
"""List of netcdf files."""
@@ -639,22 +656,45 @@ def has_different_structures(self, rtol=1e-05, atol=1e-08) -> str:
639
656
640
657
return "\n " .join (lines )
641
658
642
- #def apply(self, func_or_string, args=(), **kwargs):
643
- # """
644
- # Applies function to all ``abifiles`` available in the robot.
659
+ def _get_ref_abifile_from_basename (self , ref_basename : str | None ):
660
+ """
661
+ Find reference abifile. If None, the first file in the robot is used.
662
+ """
663
+ ref_file = self .abifiles [0 ]
664
+ if ref_basename is None :
665
+ return ref_file
645
666
646
- # Args:
647
- # func_or_string: If callable, the output of func_or_string(abifile, ...) is used.
648
- # If string, the output of getattr(abifile, func_or_string)(...)
649
- # args (tuple): Positional arguments to pass to function in addition to the array/series
650
- # kwargs: Additional keyword arguments will be passed as keywords to the function
667
+ for i , abifile in enumerate (self .abifiles ):
668
+ if abifile .basename == ref_basename :
669
+ return abifile
651
670
652
- # Return: List of results
653
- # """
654
- # if callable(func_or_string):
655
- # return [func_or_string(abifile, *args, *kwargs) for abifile in self.abifiles]
656
- # else:
657
- # return [duck.getattrd(abifile, func_or_string)(*args, **kwargs) for abifile in self.abifiles]
671
+ raise ValueError (f"Cannot find { ref_basename = } " )
672
+
673
+ @staticmethod
674
+ def _compare_attr_name (aname : str , ref_abifile , other_abifile ) -> None :
675
+ """
676
+ Compare the value of attribute `aname` in two files.
677
+ """
678
+ # Get attributes in abifile first, then in abifile.r, else raise.
679
+ if hasattr (ref_abifile , aname ):
680
+ val1 , val2 = getattr (ref_abifile , aname ), getattr (other_abifile , aname )
681
+
682
+ elif hasattr (ref_abifile , "r" ) and hasattr (ref_abifile .r , aname ):
683
+ val1 , val2 = getattr (ref_abifile .r , aname ), getattr (other_abifile .r , aname )
684
+
685
+ else :
686
+ raise AttributeError (f"Cannot find attribute `{ aname = } `" )
687
+
688
+ # Now compare val1 and val2 taking into account the type.
689
+ if isinstance (val1 , (str , int , float , Structure )):
690
+ eq = val1 == val2
691
+ elif isinstance (val1 , np .ndarray ):
692
+ eq = np .allclose (val1 , val2 )
693
+ else :
694
+ raise TypeError (f"Don't know how to handle comparison for type: { type (val1 )} " )
695
+
696
+ if not eq :
697
+ raise ValueError (f"Different values of { aname = } , { val1 = } , { val2 = } " )
658
698
659
699
def is_sortable (self , aname : str , raise_exc : bool = False ) -> bool :
660
700
"""
@@ -812,16 +852,6 @@ def close(self) -> None:
812
852
print ("Exception while closing: " , abifile .filepath )
813
853
print (exc )
814
854
815
- #def get_attributes(self, attr_name, obj=None, retdict=False):
816
- # od = OrderedDict()
817
- # for label, abifile in self.items():
818
- # obj = abifile if obj is None else getattr(abifile, obj)
819
- # od[label] = getattr(obj, attr_name)
820
- # if retdict:
821
- # return od
822
- # else:
823
- # return list(od.values())
824
-
825
855
def _exec_funcs (self , funcs , arg ) -> dict :
826
856
"""
827
857
Execute list of callable functions. Each function receives arg as argument.
@@ -1236,7 +1266,7 @@ def plot_abs_conv(ax1, ax2, xs, yvals, abs_conv, xlabel, fontsize, hatch, **kwar
1236
1266
"""
1237
1267
y_xmax = yvals [- 1 ]
1238
1268
span_style = dict (alpha = 0.2 , color = "green" , hatch = hatch )
1239
- ax1 .axhspan (y_xmax - abs_conv , y_xmax + abs_conv , label = r"$|y-y(x_{max})} | \leq %s$" % abs_conv , ** span_style )
1269
+ ax1 .axhspan (y_xmax - abs_conv , y_xmax + abs_conv , label = r"$|y-y(x_{max})| \leq %s$" % abs_conv , ** span_style )
1240
1270
1241
1271
# Plot |y - y_xmax| in log scale on ax2.
1242
1272
ax2 .plot (xs , np .abs (yvals - y_xmax ), ** kwargs )
0 commit comments