5
5
import re
6
6
import traceback
7
7
from datetime import datetime
8
- from typing import Any
8
+ from typing import Any , NamedTuple
9
9
10
10
from bson import ObjectId
11
11
from explainaboard import DatalabLoaderOption , FileType , Source , get_processor
16
16
from explainaboard_web .impl .db_utils .dataset_db_utils import DatasetDBUtils
17
17
from explainaboard_web .impl .db_utils .db_utils import DBUtils
18
18
from explainaboard_web .impl .db_utils .user_db_utils import UserDBUtils
19
+ from explainaboard_web .impl .internal_models .system_model import SystemModel
19
20
from explainaboard_web .impl .storage import get_storage
20
- from explainaboard_web .impl .utils import (
21
- abort_with_error_message ,
22
- binarize_bson ,
23
- unbinarize_bson ,
24
- )
21
+ from explainaboard_web .impl .utils import abort_with_error_message , binarize_bson
25
22
from explainaboard_web .models import (
26
23
AnalysisCase ,
27
24
System ,
28
- SystemInfo ,
29
25
SystemMetadata ,
30
26
SystemMetadataUpdatable ,
31
27
SystemOutput ,
32
28
SystemOutputProps ,
33
- SystemsReturn ,
34
29
)
35
30
from pymongo .client_session import ClientSession
36
31
37
32
38
33
class SystemDBUtils :
39
34
40
- _EMAIL_RE = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
41
35
_COLON_RE = r"^([A-Za-z0-9_-]+): (.+)$"
42
36
_SYSTEM_OUTPUT_CONST = "__SYSOUT__"
43
37
@@ -81,65 +75,13 @@ def _parse_system_details_in_doc(
81
75
metadata .system_details = parsed
82
76
document ["system_details" ] = parsed
83
77
84
- @staticmethod
85
- def system_from_dict (
86
- dikt : dict [str , Any ], include_metric_stats : bool = False
87
- ) -> System :
88
- document : dict [str , Any ] = dikt .copy ()
89
- if document .get ("_id" ):
90
- document ["system_id" ] = str (document .pop ("_id" ))
91
- if document .get ("is_private" ) is None :
92
- document ["is_private" ] = True
93
-
94
- # Parse the shared users
95
- shared_users = document .get ("shared_users" , None )
96
- if shared_users is None or len (shared_users ) == 0 :
97
- document .pop ("shared_users" , None )
98
- else :
99
- for user in shared_users :
100
- if not re .fullmatch (SystemDBUtils ._EMAIL_RE , user ):
101
- abort_with_error_message (
102
- 400 , f"invalid email address for shared user { user } "
103
- )
104
-
105
- metric_stats = []
106
- if "metric_stats" in document :
107
- metric_stats = document ["metric_stats" ]
108
- document ["metric_stats" ] = []
109
-
110
- # FIXME(lyuyang): The following for loop is added to work around an issue
111
- # related to default values of models. Previously, the generated models
112
- # don't enforce required attributes. This function exploits that loophole.
113
- # Now that we have fixed that loophole, this function needs some major
114
- # refactoring. None was assigned for these fields before implicitly. Now
115
- # we assign them explicitly so this hack does not change the current
116
- # behavior.
117
- for required_field in (
118
- "created_at" ,
119
- "last_modified" ,
120
- "system_id" ,
121
- "system_info" ,
122
- "metric_stats" ,
123
- ):
124
- if required_field not in document :
125
- document [required_field ] = None
126
-
127
- system = System .from_dict (document )
128
- if include_metric_stats :
129
- # Unbinarize to numpy array and set explicitly
130
- system .metric_stats = [
131
- [unbinarize_bson (y ) for y in x ] for x in metric_stats
132
- ]
133
- return system
134
-
135
78
@staticmethod
136
79
def query_systems (
137
80
query : list | dict ,
138
81
page : int ,
139
82
page_size : int ,
140
83
sort : list | None = None ,
141
- include_metric_stats : bool = False ,
142
- ):
84
+ ) -> FindSystemsReturn :
143
85
144
86
permissions_list = [{"is_private" : False }]
145
87
if get_user ().is_authenticated :
@@ -163,18 +105,14 @@ def query_systems(
163
105
users = UserDBUtils .find_users (list (ids ))
164
106
id_to_preferred_username = {user .id : user .preferred_username for user in users }
165
107
166
- systems : list [System ] = []
167
- if len (documents ) == 0 :
168
- return SystemsReturn (systems , 0 )
108
+ systems : list [SystemModel ] = []
169
109
170
110
for doc in documents :
171
111
doc ["preferred_username" ] = id_to_preferred_username [doc ["creator" ]]
172
- system = SystemDBUtils .system_from_dict (
173
- doc , include_metric_stats = include_metric_stats
174
- )
112
+ system = SystemModel .from_dict (doc )
175
113
systems .append (system )
176
114
177
- return SystemsReturn (systems , total )
115
+ return FindSystemsReturn (systems , total )
178
116
179
117
@staticmethod
180
118
def find_systems (
@@ -191,9 +129,8 @@ def find_systems(
191
129
sort : list | None = None ,
192
130
creator : str | None = None ,
193
131
shared_users : list [str ] | None = None ,
194
- include_metric_stats : bool = False ,
195
132
dataset_list : list [tuple [str , str , str ]] | None = None ,
196
- ) -> SystemsReturn :
133
+ ) -> FindSystemsReturn :
197
134
"""find multiple systems that matches the filters"""
198
135
199
136
search_conditions : list [dict [str , Any ]] = []
@@ -230,18 +167,14 @@ def find_systems(
230
167
]
231
168
search_conditions .append ({"$or" : dataset_dicts })
232
169
233
- systems_return = SystemDBUtils .query_systems (
234
- search_conditions ,
235
- page ,
236
- page_size ,
237
- sort ,
238
- include_metric_stats ,
170
+ systems , total = SystemDBUtils .query_systems (
171
+ search_conditions , page , page_size , sort
239
172
)
240
173
if ids and not sort :
241
174
# preserve id order if no `sort` is provided
242
175
orders = {sys_id : i for i , sys_id in enumerate (ids )}
243
- systems_return . systems .sort (key = lambda sys : orders [sys .system_id ])
244
- return systems_return
176
+ systems .sort (key = lambda sys : orders [sys .system_id ])
177
+ return FindSystemsReturn ( systems , total )
245
178
246
179
@staticmethod
247
180
def _load_sys_output (
@@ -398,7 +331,7 @@ def _validate_and_create_system():
398
331
system .update (metadata .to_dict ())
399
332
# -- parse the system details if getting a string from the frontend
400
333
SystemDBUtils ._parse_system_details_in_doc (system , metadata )
401
- return SystemDBUtils . system_from_dict (system )
334
+ return SystemModel . from_dict (system )
402
335
403
336
system = _validate_and_create_system ()
404
337
@@ -429,15 +362,13 @@ def _validate_and_create_system():
429
362
overall_statistics = SystemDBUtils ._process (
430
363
system , metadata , system_output_data , custom_features , custom_analyses
431
364
)
432
- binarized_stats = [
365
+ binarized_metric_stats = [
433
366
[binarize_bson (y .get_data ()) for y in x ]
434
367
for x in overall_statistics .metric_stats
435
368
]
436
369
437
370
# -- add the analysis results to the system object
438
371
sys_info = overall_statistics .sys_info
439
- system .system_info = SystemInfo .from_dict (sys_info .to_dict ())
440
- system .metric_stats = binarized_stats
441
372
442
373
if sys_info .analysis_levels and sys_info .results .overall :
443
374
# store overall metrics in the DB so they can be queried
@@ -456,6 +387,8 @@ def db_operations(session: ClientSession) -> str:
456
387
document = general_to_dict (system )
457
388
document .pop ("system_id" )
458
389
document .pop ("preferred_username" )
390
+ document ["system_info" ] = sys_info .to_dict ()
391
+ document ["metric_stats" ] = binarized_metric_stats
459
392
system_id = DBUtils .insert_one (
460
393
DBUtils .DEV_SYSTEM_METADATA , document , session = session
461
394
)
@@ -508,9 +441,6 @@ def db_operations(session: ClientSession) -> str:
508
441
system_id = DBUtils .execute_transaction (db_operations )
509
442
system .system_id = system_id
510
443
511
- # -- replace things that can't be returned through JSON for now
512
- system .metric_stats = []
513
-
514
444
# -- return the system
515
445
return system
516
446
except ValueError as e :
@@ -553,7 +483,7 @@ def find_system_by_id(system_id: str):
553
483
500 , "system creator ID not found in DB, please contact the sysadmins"
554
484
)
555
485
sys_doc ["preferred_username" ] = user .preferred_username
556
- system = SystemDBUtils . system_from_dict (sys_doc )
486
+ system = SystemModel . from_dict (sys_doc )
557
487
return system
558
488
559
489
@staticmethod
@@ -623,3 +553,8 @@ def db_operations(session: ClientSession) -> bool:
623
553
return True
624
554
625
555
return DBUtils .execute_transaction (db_operations )
556
+
557
+
558
+ class FindSystemsReturn (NamedTuple ):
559
+ systems : list [SystemModel ]
560
+ total : int
0 commit comments