Skip to content

Commit eb9987a

Browse files
committed
on-demand loading of system_info and metric_stats
1 parent fc40c06 commit eb9987a

File tree

10 files changed

+136
-155
lines changed

10 files changed

+136
-155
lines changed

backend/src/impl/benchmark_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def load_sys_infos(config: BenchmarkConfig) -> list[dict]:
8383

8484
systems = systems_return.systems
8585
for system in systems:
86-
temp = system.system_info.to_dict()
86+
temp = system.get_system_info().to_dict()
8787
# Don't include systems with no dataset
8888
if temp["dataset_name"] is not None:
8989
temp["creator"] = system.creator.split("@")[0]

backend/src/impl/db_utils/system_db_utils.py

Lines changed: 22 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
import traceback
77
from datetime import datetime
8-
from typing import Any
8+
from typing import Any, NamedTuple
99

1010
from bson import ObjectId
1111
from explainaboard import DatalabLoaderOption, FileType, Source, get_processor
@@ -16,28 +16,22 @@
1616
from explainaboard_web.impl.db_utils.dataset_db_utils import DatasetDBUtils
1717
from explainaboard_web.impl.db_utils.db_utils import DBUtils
1818
from explainaboard_web.impl.db_utils.user_db_utils import UserDBUtils
19+
from explainaboard_web.impl.internal_models.system_model import SystemModel
1920
from explainaboard_web.impl.storage import get_storage
20-
from explainaboard_web.impl.utils import (
21-
abort_with_error_message,
22-
binarize_bson,
23-
unbinarize_bson,
24-
)
21+
from explainaboard_web.impl.utils import abort_with_error_message, binarize_bson
2522
from explainaboard_web.models import (
2623
AnalysisCase,
2724
System,
28-
SystemInfo,
2925
SystemMetadata,
3026
SystemMetadataUpdatable,
3127
SystemOutput,
3228
SystemOutputProps,
33-
SystemsReturn,
3429
)
3530
from pymongo.client_session import ClientSession
3631

3732

3833
class SystemDBUtils:
3934

40-
_EMAIL_RE = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
4135
_COLON_RE = r"^([A-Za-z0-9_-]+): (.+)$"
4236
_SYSTEM_OUTPUT_CONST = "__SYSOUT__"
4337

@@ -81,65 +75,13 @@ def _parse_system_details_in_doc(
8175
metadata.system_details = parsed
8276
document["system_details"] = parsed
8377

84-
@staticmethod
85-
def system_from_dict(
86-
dikt: dict[str, Any], include_metric_stats: bool = False
87-
) -> System:
88-
document: dict[str, Any] = dikt.copy()
89-
if document.get("_id"):
90-
document["system_id"] = str(document.pop("_id"))
91-
if document.get("is_private") is None:
92-
document["is_private"] = True
93-
94-
# Parse the shared users
95-
shared_users = document.get("shared_users", None)
96-
if shared_users is None or len(shared_users) == 0:
97-
document.pop("shared_users", None)
98-
else:
99-
for user in shared_users:
100-
if not re.fullmatch(SystemDBUtils._EMAIL_RE, user):
101-
abort_with_error_message(
102-
400, f"invalid email address for shared user {user}"
103-
)
104-
105-
metric_stats = []
106-
if "metric_stats" in document:
107-
metric_stats = document["metric_stats"]
108-
document["metric_stats"] = []
109-
110-
# FIXME(lyuyang): The following for loop is added to work around an issue
111-
# related to default values of models. Previously, the generated models
112-
# don't enforce required attributes. This function exploits that loophole.
113-
# Now that we have fixed that loophole, this function needs some major
114-
# refactoring. None was assigned for these fields before implicitly. Now
115-
# we assign them explicitly so this hack does not change the current
116-
# behavior.
117-
for required_field in (
118-
"created_at",
119-
"last_modified",
120-
"system_id",
121-
"system_info",
122-
"metric_stats",
123-
):
124-
if required_field not in document:
125-
document[required_field] = None
126-
127-
system = System.from_dict(document)
128-
if include_metric_stats:
129-
# Unbinarize to numpy array and set explicitly
130-
system.metric_stats = [
131-
[unbinarize_bson(y) for y in x] for x in metric_stats
132-
]
133-
return system
134-
13578
@staticmethod
13679
def query_systems(
13780
query: list | dict,
13881
page: int,
13982
page_size: int,
14083
sort: list | None = None,
141-
include_metric_stats: bool = False,
142-
):
84+
) -> FindSystemsReturn:
14385

14486
permissions_list = [{"is_private": False}]
14587
if get_user().is_authenticated:
@@ -163,18 +105,14 @@ def query_systems(
163105
users = UserDBUtils.find_users(list(ids))
164106
id_to_preferred_username = {user.id: user.preferred_username for user in users}
165107

166-
systems: list[System] = []
167-
if len(documents) == 0:
168-
return SystemsReturn(systems, 0)
108+
systems: list[SystemModel] = []
169109

170110
for doc in documents:
171111
doc["preferred_username"] = id_to_preferred_username[doc["creator"]]
172-
system = SystemDBUtils.system_from_dict(
173-
doc, include_metric_stats=include_metric_stats
174-
)
112+
system = SystemModel.from_dict(doc)
175113
systems.append(system)
176114

177-
return SystemsReturn(systems, total)
115+
return FindSystemsReturn(systems, total)
178116

179117
@staticmethod
180118
def find_systems(
@@ -191,9 +129,8 @@ def find_systems(
191129
sort: list | None = None,
192130
creator: str | None = None,
193131
shared_users: list[str] | None = None,
194-
include_metric_stats: bool = False,
195132
dataset_list: list[tuple[str, str, str]] | None = None,
196-
) -> SystemsReturn:
133+
) -> FindSystemsReturn:
197134
"""find multiple systems that matches the filters"""
198135

199136
search_conditions: list[dict[str, Any]] = []
@@ -230,18 +167,14 @@ def find_systems(
230167
]
231168
search_conditions.append({"$or": dataset_dicts})
232169

233-
systems_return = SystemDBUtils.query_systems(
234-
search_conditions,
235-
page,
236-
page_size,
237-
sort,
238-
include_metric_stats,
170+
systems, total = SystemDBUtils.query_systems(
171+
search_conditions, page, page_size, sort
239172
)
240173
if ids and not sort:
241174
# preserve id order if no `sort` is provided
242175
orders = {sys_id: i for i, sys_id in enumerate(ids)}
243-
systems_return.systems.sort(key=lambda sys: orders[sys.system_id])
244-
return systems_return
176+
systems.sort(key=lambda sys: orders[sys.system_id])
177+
return FindSystemsReturn(systems, total)
245178

246179
@staticmethod
247180
def _load_sys_output(
@@ -398,7 +331,7 @@ def _validate_and_create_system():
398331
system.update(metadata.to_dict())
399332
# -- parse the system details if getting a string from the frontend
400333
SystemDBUtils._parse_system_details_in_doc(system, metadata)
401-
return SystemDBUtils.system_from_dict(system)
334+
return SystemModel.from_dict(system)
402335

403336
system = _validate_and_create_system()
404337

@@ -429,15 +362,13 @@ def _validate_and_create_system():
429362
overall_statistics = SystemDBUtils._process(
430363
system, metadata, system_output_data, custom_features, custom_analyses
431364
)
432-
binarized_stats = [
365+
binarized_metric_stats = [
433366
[binarize_bson(y.get_data()) for y in x]
434367
for x in overall_statistics.metric_stats
435368
]
436369

437370
# -- add the analysis results to the system object
438371
sys_info = overall_statistics.sys_info
439-
system.system_info = SystemInfo.from_dict(sys_info.to_dict())
440-
system.metric_stats = binarized_stats
441372

442373
if sys_info.analysis_levels and sys_info.results.overall:
443374
# store overall metrics in the DB so they can be queried
@@ -456,6 +387,8 @@ def db_operations(session: ClientSession) -> str:
456387
document = general_to_dict(system)
457388
document.pop("system_id")
458389
document.pop("preferred_username")
390+
document["system_info"] = sys_info.to_dict()
391+
document["metric_stats"] = binarized_metric_stats
459392
system_id = DBUtils.insert_one(
460393
DBUtils.DEV_SYSTEM_METADATA, document, session=session
461394
)
@@ -508,9 +441,6 @@ def db_operations(session: ClientSession) -> str:
508441
system_id = DBUtils.execute_transaction(db_operations)
509442
system.system_id = system_id
510443

511-
# -- replace things that can't be returned through JSON for now
512-
system.metric_stats = []
513-
514444
# -- return the system
515445
return system
516446
except ValueError as e:
@@ -553,7 +483,7 @@ def find_system_by_id(system_id: str):
553483
500, "system creator ID not found in DB, please contact the sysadmins"
554484
)
555485
sys_doc["preferred_username"] = user.preferred_username
556-
system = SystemDBUtils.system_from_dict(sys_doc)
486+
system = SystemModel.from_dict(sys_doc)
557487
return system
558488

559489
@staticmethod
@@ -623,3 +553,8 @@ def db_operations(session: ClientSession) -> bool:
623553
return True
624554

625555
return DBUtils.execute_transaction(db_operations)
556+
557+
558+
class FindSystemsReturn(NamedTuple):
559+
systems: list[SystemModel]
560+
total: int

backend/src/impl/default_controllers_impl.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def systems_get(
292292

293293
dir = ASCENDING if sort_direction == "asc" else DESCENDING
294294

295-
return SystemDBUtils.find_systems(
295+
systems, total = SystemDBUtils.find_systems(
296296
page=page,
297297
page_size=page_size,
298298
system_name=system_name,
@@ -304,6 +304,7 @@ def systems_get(
304304
creator=creator,
305305
shared_users=shared_users,
306306
)
307+
return SystemsReturn(systems, total)
307308

308309

309310
def systems_post(body: SystemCreateProps) -> System:
@@ -402,33 +403,29 @@ def systems_analyses_post(body: SystemsAnalysesBody):
402403
system_ids: list = system_ids_str.split(",")
403404
page = 0
404405
page_size = len(system_ids)
405-
systems: list[System] = SystemDBUtils.find_systems(
406-
ids=system_ids,
407-
page=page,
408-
page_size=page_size,
409-
include_metric_stats=True,
406+
systems = SystemDBUtils.find_systems(
407+
ids=system_ids, page=page, page_size=page_size
410408
).systems
411-
systems_len = len(systems)
412-
if systems_len == 0:
409+
if len(systems) == 0:
413410
return SystemAnalysesReturn(system_analyses)
414411

415412
# performance significance test if there are two systems
416413
sig_info = []
417414
if len(systems) == 2:
418415

419-
system1_info: SystemInfo = systems[0].system_info
416+
system1_info: SystemInfo = systems[0].get_system_info()
420417
system1_info_dict = general_to_dict(system1_info)
421418
system1_output_info = SysOutputInfo.from_dict(system1_info_dict)
422419

423420
system1_metric_stats: list[SimpleMetricStats] = [
424-
SimpleMetricStats(stat) for stat in systems[0].metric_stats[0]
421+
SimpleMetricStats(stat) for stat in systems[0].get_metric_stats()[0]
425422
]
426423

427-
system2_info: SystemInfo = systems[1].system_info
424+
system2_info: SystemInfo = systems[1].get_system_info()
428425
system2_info_dict = general_to_dict(system2_info)
429426
system2_output_info = SysOutputInfo.from_dict(system2_info_dict)
430427
system2_metric_stats: list[SimpleMetricStats] = [
431-
SimpleMetricStats(stat) for stat in systems[1].metric_stats[0]
428+
SimpleMetricStats(stat) for stat in systems[1].get_metric_stats()[0]
432429
]
433430

434431
sig_info = pairwise_significance_test(
@@ -439,7 +436,7 @@ def systems_analyses_post(body: SystemsAnalysesBody):
439436
)
440437

441438
for system in systems:
442-
system_info: SystemInfo = system.system_info
439+
system_info = system.get_system_info()
443440
system_info_dict = general_to_dict(system_info)
444441
system_output_info = SysOutputInfo.from_dict(system_info_dict)
445442

@@ -464,12 +461,14 @@ def systems_analyses_post(body: SystemsAnalysesBody):
464461
)
465462

466463
processor = get_processor(TaskType(system_output_info.task_name))
467-
metric_stats = [[SimpleMetricStats(y) for y in x] for x in system.metric_stats]
464+
metric_stats = [
465+
[SimpleMetricStats(y) for y in x] for x in system.get_metric_stats()
466+
]
468467

469468
# Get analysis cases
470469
analysis_cases = []
471470
case_ids = None
472-
for i, analysis_level in enumerate(system.system_info.analysis_levels):
471+
for analysis_level in system.get_system_info().analysis_levels:
473472
level_cases = SystemDBUtils.find_analysis_cases(
474473
system_id=system.system_id,
475474
level=analysis_level.name,
@@ -486,7 +485,10 @@ def systems_analyses_post(body: SystemsAnalysesBody):
486485
metric_stats,
487486
skip_failed_analyses=True,
488487
)
489-
single_analysis = SingleAnalysis(analysis_results=processor_result)
488+
single_analysis = SingleAnalysis(
489+
system_info=system_info,
490+
analysis_results=processor_result,
491+
)
490492
system_analyses.append(single_analysis)
491493

492494
return SystemAnalysesReturn(system_analyses, sig_info)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from __future__ import annotations
2+
3+
import re
4+
from typing import Any
5+
6+
from explainaboard_web.impl.db_utils.db_utils import DBUtils
7+
from explainaboard_web.impl.utils import abort_with_error_message, unbinarize_bson
8+
from explainaboard_web.models.system import System
9+
from explainaboard_web.models.system_info import SystemInfo
10+
11+
12+
class SystemModel(System):
13+
"""Same as System but implements several helper functions that retrieves
14+
additional information and persists data to the DB.
15+
"""
16+
17+
@classmethod
18+
def from_dict(cls, dikt: dict) -> SystemModel:
19+
"""Validates and initialize a SystemModel object from a dict"""
20+
document: dict[str, Any] = dikt.copy()
21+
if document.get("_id"):
22+
document["system_id"] = str(document.pop("_id"))
23+
24+
# Parse the shared users
25+
shared_users = document.get("shared_users", None)
26+
if shared_users is None or len(shared_users) == 0:
27+
document.pop("shared_users", None)
28+
else:
29+
for user in shared_users:
30+
if not re.fullmatch(
31+
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", user
32+
):
33+
abort_with_error_message(
34+
400, f"invalid email address for shared user {user}"
35+
)
36+
37+
# FIXME(lyuyang): The following for loop is added to work around an issue
38+
# related to default values of models. Previously, the generated models
39+
# don't enforce required attributes. This function exploits that loophole.
40+
# Now that we have fixed that loophole, this function needs some major
41+
# refactoring. None was assigned for these fields before implicitly. Now
42+
# we assign them explicitly so this hack does not change the current
43+
# behavior.
44+
for required_field in ("created_at", "last_modified", "system_id"):
45+
if required_field not in document:
46+
document[required_field] = None
47+
48+
return super().from_dict(document)
49+
50+
def get_system_info(self) -> SystemInfo:
51+
"""retrieves system info from DB"""
52+
sys_doc = DBUtils.find_one_by_id(DBUtils.DEV_SYSTEM_METADATA, self.system_id)
53+
if not sys_doc:
54+
abort_with_error_message(404, f"system id: {self.system_id} not found")
55+
return SystemInfo.from_dict(sys_doc["system_info"])
56+
57+
def get_metric_stats(self) -> list[list[float]]:
58+
"""retrieves metric stats from DB"""
59+
sys_doc = DBUtils.find_one_by_id(DBUtils.DEV_SYSTEM_METADATA, self.system_id)
60+
if not sys_doc:
61+
abort_with_error_message(404, f"system id: {self.system_id} not found")
62+
return [[unbinarize_bson(y) for y in x] for x in sys_doc["metric_stats"]]

0 commit comments

Comments
 (0)