11from functools import lru_cache
2- from typing import Callable , Dict , List , Tuple , Union
2+ from typing import Callable , Dict , List , Tuple , Union , Type
33
44import pandas as pd
55
@@ -39,15 +39,20 @@ def __init__(self, table: Table, by: List[Variable]):
3939 df = table_to_frame (table , include_metas = True )
4040 # observed=True keeps only groups with at least one instance
4141 self .group_by = df .groupby ([a .name for a in by ], observed = True )
42+ self .by = tuple (by )
4243
4344 # lru_cache that caches on the object level
4445 self .compute_aggregation = lru_cache ()(self ._compute_aggregation )
4546
47+ AggDescType = Union [str ,
48+ Callable ,
49+ Tuple [str , Union [str , Callable ]],
50+ Tuple [str , Union [str , Callable ], Union [Type [Variable ], bool ]]
51+ ]
52+
4653 def aggregate (
4754 self ,
48- aggregations : Dict [
49- Variable , List [Union [str , Callable , Tuple [str , Union [str , Callable ]]]]
50- ],
55+ aggregations : Dict [Variable , List [AggDescType ]],
5156 callback : Callable = dummy_callback ,
5257 ) -> Table :
5358 """
@@ -57,12 +62,16 @@ def aggregate(
5762 ----------
5863 aggregations
5964 The dictionary that defines aggregations that need to be computed
60- for variables. We support two formats:
65+ for variables. We support three formats:
6166 - {variable name: [agg function 1, agg function 2]}
6267 - {variable name: [(agg name 1, agg function 1), (agg name 2, agg function 2)]}
68+ - {variable name: [(agg name 1, agg function 1, output_variable_type1), ...]}
6369 Where agg name is the aggregation name used in the output column name.
6470 Aggregation function can be either function or string that defines
6571 aggregation in Pandas (e.g. mean).
72+ output_variable_type can be a type for a new variable, True to copy
73+ the input variable, or False to create a new variable of the same type
74+ as the input
6675 callback
6776 Callback function to report the progress
6877
@@ -75,29 +84,44 @@ def aggregate(
7584 count = 0
7685
7786 result_agg = []
87+ output_variables = []
7888 for col , aggs in aggregations .items ():
7989 for agg in aggs :
80- res = self ._compute_aggregation (col , agg )
90+ res , var = self ._compute_aggregation (col , agg )
8191 result_agg .append (res )
92+ output_variables .append (var )
8293 count += 1
8394 callback (count / num_aggs * 0.8 )
8495
85- agg_table = self ._aggregations_to_table (result_agg )
96+ agg_table = self ._aggregations_to_table (result_agg , output_variables )
8697 callback (1 )
8798 return agg_table
8899
89100 def _compute_aggregation (
90- self , col : Variable , agg : Union [str , Callable , Tuple [str , Union [str , Callable ]]]
91- ) -> pd .Series :
101+ self , col : Variable , agg : AggDescType ) -> Tuple [pd .Series , Variable ]:
92102 # use named aggregation to avoid issues with same column names when reset_index
93103 if isinstance (agg , tuple ):
94- name , agg = agg
104+ name , agg , var_type , * _ = ( * agg , None )
95105 else :
96106 name = agg if isinstance (agg , str ) else agg .__name__
107+ var_type = None
97108 col_name = f"{ col .name } - { name } "
98- return self .group_by [col .name ].agg (** {col_name : agg })
99-
100- def _aggregations_to_table (self , aggregations : List [pd .Series ]) -> Table :
109+ agg_col = self .group_by [col .name ].agg (** {col_name : agg })
110+ if var_type is True :
111+ var = col .copy (name = col_name )
112+ elif var_type is False :
113+ var = col .make (name = col_name )
114+ elif var_type is None :
115+ var = None
116+ else :
117+ assert issubclass (var_type , Variable )
118+ var = var_type .make (name = col_name )
119+ return agg_col , var
120+
121+ def _aggregations_to_table (
122+ self ,
123+ aggregations : List [pd .Series ],
124+ output_variables : List [Union [Variable , None ]]) -> Table :
101125 """Concatenate aggregation series and convert back to Table"""
102126 if aggregations :
103127 df = pd .concat (aggregations , axis = 1 )
@@ -107,7 +131,7 @@ def _aggregations_to_table(self, aggregations: List[pd.Series]) -> Table:
107131 df = df .drop (columns = df .columns )
108132 gb_attributes = df .index .names
109133 df = df .reset_index () # move group by var that are in index to columns
110- table = table_from_frame (df )
134+ table = table_from_frame (df , variables = ( * self . by , * output_variables ) )
111135
112136 # group by variables should be last two columns in metas in the output
113137 metas = table .domain .metas
0 commit comments