
Commit 67038ad

implement update_overall_statistics() and refactor create_system()

1 parent eb9987a

3 files changed: +218 −164 lines changed

backend/src/impl/db_utils/db_utils.py

Lines changed: 18 additions & 8 deletions
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Optional, TypeVar
+from typing import TypeVar

 from bson.objectid import InvalidId, ObjectId
 from explainaboard_web.impl.db import get_db
@@ -88,7 +90,10 @@ def insert_many(

     @staticmethod
     def find_one_by_id(
-        collection: DBCollection, docid: str, projection: Optional[dict] = None
+        collection: DBCollection,
+        docid: str,
+        projection: dict | None = None,
+        session: ClientSession | None = None,
     ):
         """
         Find and return a document with the _id field
@@ -102,11 +107,16 @@ def find_one_by_id(
             """Mongo accepts custom ID"""
             _id = docid
         finally:
-            return DBUtils.get_collection(collection).find_one({"_id": _id}, projection)
+            return DBUtils.get_collection(collection).find_one(
+                {"_id": _id}, projection, session=session
+            )

     @staticmethod
     def update_one_by_id(
-        collection: DBCollection, docid: str, field_to_value: dict
+        collection: DBCollection,
+        docid: str,
+        field_to_value: dict,
+        session: ClientSession | None = None,
     ) -> bool:
         """
         Update a document with the _id field
@@ -117,7 +127,7 @@ def update_one_by_id(
         """
         try:
             result: UpdateResult = DBUtils.get_collection(collection).update_one(
-                {"_id": ObjectId(docid)}, {"$set": field_to_value}
+                {"_id": ObjectId(docid)}, {"$set": field_to_value}, session=session
             )
             if int(result.modified_count) == 1:
                 return True
@@ -179,11 +189,11 @@ def count(collection: DBCollection, filt: dict = None) -> int:
     @staticmethod
     def find(
         collection: DBCollection,
-        filt: Optional[dict] = None,
-        sort: Optional[list] = None,
+        filt: dict | None = None,
+        sort: list | None = None,
         skip=0,
         limit: int = 10,
-        projection: Optional[dict] = None,
+        projection: dict | None = None,
     ) -> tuple[Cursor, int]:
         """
         Find multiple documents
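
The point of threading the optional session argument through these helpers is that each read and write can join a multi-document MongoDB transaction. A minimal sketch of that pattern with pymongo, assuming a MongoClient named client; the project's actual DBUtils.execute_transaction wrapper is not shown in this diff and may differ:

from collections.abc import Callable
from typing import TypeVar

from pymongo import MongoClient
from pymongo.client_session import ClientSession

T = TypeVar("T")
client = MongoClient()  # assumption: real connection settings come from app config


def execute_transaction(operations: Callable[[ClientSession], T]) -> T:
    # with_transaction() runs the callback inside a transaction, retrying on
    # transient errors; every read/write that received this session commits
    # or aborts together.
    with client.start_session() as session:
        return session.with_transaction(operations)


# Usage with the DBUtils helpers above: forward session= to each call so
# the insert and the update are atomic.
def ops(session: ClientSession) -> bool:
    doc_id = DBUtils.insert_one(DBUtils.DEV_SYSTEM_METADATA, {"x": 1}, session=session)
    return DBUtils.update_one_by_id(
        DBUtils.DEV_SYSTEM_METADATA, doc_id, {"x": 2}, session=session
    )


ok = execute_transaction(ops)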

backend/src/impl/db_utils/system_db_utils.py

Lines changed: 19 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,15 @@
 from typing import Any, NamedTuple

 from bson import ObjectId
-from explainaboard import DatalabLoaderOption, FileType, Source, get_processor
-from explainaboard.loaders.file_loader import FileLoaderReturn
+from explainaboard import DatalabLoaderOption, FileType, Source
 from explainaboard.loaders.loader_registry import get_loader_class
-from explainaboard.serialization.legacy import general_to_dict
 from explainaboard_web.impl.auth import get_user
 from explainaboard_web.impl.db_utils.dataset_db_utils import DatasetDBUtils
 from explainaboard_web.impl.db_utils.db_utils import DBUtils
 from explainaboard_web.impl.db_utils.user_db_utils import UserDBUtils
 from explainaboard_web.impl.internal_models.system_model import SystemModel
 from explainaboard_web.impl.storage import get_storage
-from explainaboard_web.impl.utils import abort_with_error_message, binarize_bson
+from explainaboard_web.impl.utils import abort_with_error_message
 from explainaboard_web.models import (
     AnalysisCase,
     System,
@@ -33,7 +31,6 @@
 class SystemDBUtils:

     _COLON_RE = r"^([A-Za-z0-9_-]+): (.+)$"
-    _SYSTEM_OUTPUT_CONST = "__SYSOUT__"

     @staticmethod
     def _parse_colon_line(line) -> tuple[str, str]:
@@ -178,17 +175,15 @@ def find_systems(

     @staticmethod
     def _load_sys_output(
-        system: System,
-        metadata: SystemMetadata,
+        system: SystemModel,
         system_output: SystemOutputProps,
         custom_dataset: SystemOutputProps | None,
-        dataset_custom_features: dict,
     ):
         """
         Load the system output from the uploaded file
         """
         if custom_dataset:
-            return get_loader_class(task=metadata.task)(
+            return get_loader_class(task=system.task)(
                 dataset_data=custom_dataset.data,
                 output_data=system_output.data,
                 dataset_source=Source.in_memory,
@@ -198,13 +193,13 @@ def _load_sys_output(
             ).load()
         elif system.dataset:
             return (
-                get_loader_class(task=metadata.task)
+                get_loader_class(task=system.task)
                 .from_datalab(
                     dataset=DatalabLoaderOption(
                         system.dataset.dataset_name,
                         system.dataset.sub_dataset,
-                        metadata.dataset_split,
-                        custom_features=dataset_custom_features,
+                        system.dataset.split,
+                        custom_features=system.get_dataset_custom_features(),
                     ),
                     output_data=system_output.data,
                     output_file_type=FileType(system_output.file_type),
@@ -214,41 +209,6 @@
             )
         raise ValueError("neither dataset or custom_dataset is available")

-    @staticmethod
-    def _process(
-        system: System,
-        metadata: SystemMetadata,
-        system_output_data: FileLoaderReturn,
-        custom_features: dict,
-        custom_analyses: list,
-    ):
-        processor = get_processor(metadata.task)
-        metrics_lookup = {
-            metric.name: metric
-            for metric in get_processor(metadata.task).full_metric_list()
-        }
-        metric_configs = []
-        for metric_name in metadata.metric_names:
-            if metric_name not in metrics_lookup:
-                abort_with_error_message(
-                    400, f"{metric_name} is not a supported metric"
-                )
-            metric_configs.append(metrics_lookup[metric_name])
-        processor_metadata = {
-            **metadata.to_dict(),
-            "dataset_name": system.dataset.dataset_name if system.dataset else None,
-            "sub_dataset_name": system.dataset.sub_dataset if system.dataset else None,
-            "dataset_split": metadata.dataset_split,
-            "task_name": metadata.task,
-            "metric_configs": metric_configs,
-            "custom_features": custom_features,
-            "custom_analyses": custom_analyses,
-        }
-
-        return processor.get_overall_statistics(
-            metadata=processor_metadata, sys_output=system_output_data.samples
-        )
-
     @staticmethod
     def _find_output_or_case_raw(
         system_id: str, analysis_level: str, output_ids: list[int] | None
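
The deleted _process helper corresponds to the update_overall_statistics() named in the commit title: the refactored create_system below calls system.update_overall_statistics(metadata, system_output_data, session), so this logic presumably moved onto SystemModel in the third changed file, whose diff is not reproduced here. A hypothetical sketch of the moved method, reusing the deleted body; only the method name and call site are confirmed by this diff, everything else is an assumption:

# Hypothetical reconstruction; the real method lives in the unshown
# system_model.py changes and may differ substantially.
class SystemModel(System):
    def update_overall_statistics(
        self,
        metadata: SystemMetadata,
        system_output_data: FileLoaderReturn,
        session: ClientSession,
    ) -> None:
        # Same validation the deleted _process() performed: resolve the
        # requested metric names against the processor's supported metrics.
        processor = get_processor(self.task)
        metrics_lookup = {m.name: m for m in processor.full_metric_list()}
        metric_configs = []
        for metric_name in metadata.metric_names:
            if metric_name not in metrics_lookup:
                abort_with_error_message(400, f"{metric_name} is not a supported metric")
            metric_configs.append(metrics_lookup[metric_name])
        overall_statistics = processor.get_overall_statistics(
            metadata={**metadata.to_dict(), "metric_configs": metric_configs},
            sys_output=system_output_data.samples,
        )
        # ...then persist sys_info and metric_stats with the given session,
        # e.g. via DBUtils.update_one_by_id(..., session=session).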
@@ -290,6 +250,7 @@ def _validate_and_create_system():
             user = get_user()
             system["creator"] = user.id
             system["preferred_username"] = user.preferred_username
+            system["created_at"] = system["last_modified"] = datetime.utcnow()

             if metadata.dataset_metadata_id:
                 if not metadata.dataset_split:
@@ -336,119 +297,25 @@ def _validate_and_create_system():
         system = _validate_and_create_system()

         try:
-            # -- find the dataset and grab custom features if they exist
-            dataset_custom_features = {}
-            if system.dataset:
-                dataset_info = DatasetDBUtils.find_dataset_by_id(
-                    system.dataset.dataset_id
-                )
-                if dataset_info and dataset_info.custom_features:
-                    dataset_custom_features = dict(dataset_info.custom_features)
-
             # -- load the system output into memory from the uploaded file(s)
             system_output_data = SystemDBUtils._load_sys_output(
-                system, metadata, system_output, custom_dataset, dataset_custom_features
-            )
-            system_custom_features: dict = (
-                system_output_data.metadata.custom_features or {}
+                system, system_output, custom_dataset
             )
-            custom_analyses: list = system_output_data.metadata.custom_analyses or []
-
-            # -- combine custom features from the two sources
-            custom_features = dict(system_custom_features)
-            custom_features.update(dataset_custom_features)
-
-            # -- do the actual analysis and binarize the metric stats
-            overall_statistics = SystemDBUtils._process(
-                system, metadata, system_output_data, custom_features, custom_analyses
-            )
-            binarized_metric_stats = [
-                [binarize_bson(y.get_data()) for y in x]
-                for x in overall_statistics.metric_stats
-            ]
-
-            # -- add the analysis results to the system object
-            sys_info = overall_statistics.sys_info
-
-            if sys_info.analysis_levels and sys_info.results.overall:
-                # store overall metrics in the DB so they can be queried
-                for level, result in zip(
-                    sys_info.analysis_levels, sys_info.results.overall
-                ):
-                    system.results[level.name] = {}
-                    for metric_result in result:
-                        system.results[level.name][
-                            metric_result.metric_name
-                        ] = metric_result.value
-
-            def db_operations(session: ClientSession) -> str:
-                # Insert system
-                system.created_at = system.last_modified = datetime.utcnow()
-                document = general_to_dict(system)
-                document.pop("system_id")
-                document.pop("preferred_username")
-                document["system_info"] = sys_info.to_dict()
-                document["metric_stats"] = binarized_metric_stats
-                system_id = DBUtils.insert_one(
-                    DBUtils.DEV_SYSTEM_METADATA, document, session=session
-                )
-                # Compress the system output and upload to Cloud Storage
-                insert_list = []
-                sample_list = [general_to_dict(v) for v in system_output_data.samples]
-
-                blob_name = f"{system_id}/{SystemDBUtils._SYSTEM_OUTPUT_CONST}"
-                get_storage().compress_and_upload(
-                    blob_name,
-                    json.dumps(sample_list),
-                )
-                insert_list.append(
-                    {
-                        "system_id": system_id,
-                        "analysis_level": SystemDBUtils._SYSTEM_OUTPUT_CONST,
-                        "data": blob_name,
-                    }
-                )
-                # Compress analysis cases
-                for i, (analysis_level, analysis_cases) in enumerate(
-                    zip(
-                        overall_statistics.sys_info.analysis_levels,
-                        overall_statistics.analysis_cases,
-                    )
-                ):
-                    case_list = [general_to_dict(v) for v in analysis_cases]
-
-                    blob_name = f"{system_id}/{analysis_level.name}"
-                    get_storage().compress_and_upload(blob_name, json.dumps(case_list))
-                    insert_list.append(
-                        {
-                            "system_id": system_id,
-                            "analysis_level": analysis_level.name,
-                            "data": blob_name,
-                        }
-                    )
-                # Insert system output and analysis cases
-                output_collection = DBUtils.get_system_output_collection(system_id)
-                result = DBUtils.insert_many(
-                    output_collection, insert_list, False, session
-                )
-                if not result:
-                    abort_with_error_message(
-                        400, f"failed to insert outputs for {system_id}"
-                    )
-                return system_id
-
-            # -- perform upload to the DB
-            system_id = DBUtils.execute_transaction(db_operations)
-            system.system_id = system_id
-
-            # -- return the system
-            return system
         except ValueError as e:
             traceback.print_exc()
             abort_with_error_message(400, str(e))
             # mypy doesn't seem to understand the NoReturn type in an except block.
             # It's a noop to fix it
             raise e
+        else:
+
+            def db_operations(session: ClientSession) -> None:
+                system.save_to_db(session)
+                system.save_system_output(system_output_data, session)
+                system.update_overall_statistics(metadata, system_output_data, session)
+
+            DBUtils.execute_transaction(db_operations)
+            return system

     @staticmethod
     def update_system_by_id(system_id: str, metadata: SystemMetadataUpdatable) -> bool:
@@ -506,7 +373,7 @@ def find_system_outputs(
         find multiple system outputs whose ids are in output_ids
         """
         sys_data = SystemDBUtils._find_output_or_case_raw(
-            str(system_id), SystemDBUtils._SYSTEM_OUTPUT_CONST, output_ids
+            str(system_id), SystemModel._SYSTEM_OUTPUT_CONST, output_ids
         )
         return [SystemDBUtils.system_output_from_dict(doc) for doc in sys_data]

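
After the refactor, create_system delegates persistence to three SystemModel methods (save_to_db, save_system_output, update_overall_statistics) executed inside a single transaction via DBUtils.execute_transaction. The last hunk confirms that _SYSTEM_OUTPUT_CONST now lives on SystemModel; the methods themselves are in the third changed file, which is not shown above. A hypothetical sketch of save_system_output, assembled from the deleted upload logic:

# Hypothetical reconstruction from the code deleted in create_system;
# the real implementation may differ.
class SystemModel(System):
    _SYSTEM_OUTPUT_CONST = "__SYSOUT__"  # relocated here from SystemDBUtils

    def save_system_output(
        self, system_output_data: FileLoaderReturn, session: ClientSession
    ) -> None:
        # Compress the raw system output, upload it to Cloud Storage, and
        # record the blob name in the per-system output collection: the same
        # steps the deleted db_operations() performed inline.
        sample_list = [general_to_dict(v) for v in system_output_data.samples]
        blob_name = f"{self.system_id}/{SystemModel._SYSTEM_OUTPUT_CONST}"
        get_storage().compress_and_upload(blob_name, json.dumps(sample_list))
        output_collection = DBUtils.get_system_output_collection(self.system_id)
        DBUtils.insert_many(
            output_collection,
            [
                {
                    "system_id": self.system_id,
                    "analysis_level": SystemModel._SYSTEM_OUTPUT_CONST,
                    "data": blob_name,
                }
            ],
            False,
            session,
        )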
