@@ -53,7 +53,7 @@ def test_rai_insights_empty_save_load_save(self):
53
53
# Validate, but this isn't the main check
54
54
validate_rai_insights (
55
55
rai_2 , X_train , X_test ,
56
- LABELS , ModelTask .CLASSIFICATION , None )
56
+ LABELS , ModelTask .CLASSIFICATION , None , None , None )
57
57
58
58
# Save again (this is where Issue #1046 manifested)
59
59
rai_2 .save (save_2 )
@@ -68,7 +68,8 @@ def test_rai_insights_empty_save_load_save(self):
68
68
ManagerNames .COUNTERFACTUAL ])
69
69
def test_rai_insights_save_load_add_save (self , manager_type ):
70
70
data_train , data_test , y_train , y_test , categorical_features , \
71
- continuous_features , target_name , classes = \
71
+ continuous_features , target_name , classes , \
72
+ feature_columns , feature_range_keys = \
72
73
create_adult_income_dataset ()
73
74
X_train = data_train .drop ([target_name ], axis = 1 )
74
75
@@ -120,7 +121,9 @@ def test_rai_insights_save_load_add_save(self, manager_type):
120
121
validate_rai_insights (
121
122
rai_2 , data_train , data_test ,
122
123
target_name , ModelTask .CLASSIFICATION ,
123
- categorical_features = categorical_features )
124
+ categorical_features = categorical_features ,
125
+ feature_range_keys = feature_range_keys ,
126
+ feature_columns = feature_columns )
124
127
125
128
# Save again (this is where Issue #1046 manifested)
126
129
rai_2 .save (save_2 )
@@ -135,7 +138,8 @@ def test_load_missing_dirs(self, target_dir):
135
138
# The exception is the Explainer, which always creates a file
136
139
# in its subdirectory
137
140
data_train , data_test , y_train , y_test , categorical_features , \
138
- continuous_features , target_name , classes = \
141
+ continuous_features , target_name , classes , \
142
+ feature_columns , feature_range_keys = \
139
143
create_adult_income_dataset ()
140
144
X_train = data_train .drop ([target_name ], axis = 1 )
141
145
@@ -201,7 +205,8 @@ def test_loading_rai_insights_without_model_file(self):
201
205
ManagerNames .COUNTERFACTUAL ])
202
206
def test_rai_insights_add_save_load_save (self , manager_type ):
203
207
data_train , data_test , y_train , y_test , categorical_features , \
204
- continuous_features , target_name , classes = \
208
+ continuous_features , target_name , classes , \
209
+ feature_columns , feature_range_keys = \
205
210
create_adult_income_dataset ()
206
211
X_train = data_train .drop ([target_name ], axis = 1 )
207
212
@@ -253,7 +258,9 @@ def test_rai_insights_add_save_load_save(self, manager_type):
253
258
validate_rai_insights (
254
259
rai_2 , data_train , data_test ,
255
260
target_name , ModelTask .CLASSIFICATION ,
256
- categorical_features = categorical_features )
261
+ categorical_features = categorical_features ,
262
+ feature_range_keys = feature_range_keys ,
263
+ feature_columns = feature_columns )
257
264
258
265
# Save again (this is where Issue #1081 manifested)
259
266
rai_2 .save (save_2 )
@@ -265,14 +272,21 @@ def validate_rai_insights(
265
272
test_data ,
266
273
target_column ,
267
274
task_type ,
268
- categorical_features
275
+ categorical_features ,
276
+ feature_range_keys ,
277
+ feature_columns
269
278
):
270
-
271
279
pd .testing .assert_frame_equal (rai_insights .train , train_data )
272
280
pd .testing .assert_frame_equal (rai_insights .test , test_data )
273
281
assert rai_insights .target_column == target_column
274
282
assert rai_insights .task_type == task_type
275
283
assert rai_insights .categorical_features == (categorical_features or [])
284
+ if feature_range_keys is not None :
285
+ assert feature_range_keys .sort () == \
286
+ list (rai_insights ._feature_ranges [0 ].keys ()).sort ()
287
+ if feature_columns is not None :
288
+ assert rai_insights ._feature_columns == (feature_columns or [])
289
+ assert target_column not in rai_insights ._feature_columns
276
290
if task_type == ModelTask .CLASSIFICATION :
277
291
classes = train_data [target_column ].unique ()
278
292
classes .sort ()
0 commit comments