implement update_overall_statistics() and refactor create_system()
lyuyangh committed Nov 2, 2022
1 parent eb9987a commit 67038ad
Showing 3 changed files with 218 additions and 164 deletions.
26 changes: 18 additions & 8 deletions backend/src/impl/db_utils/db_utils.py
@@ -1,6 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional, TypeVar
from typing import TypeVar

from bson.objectid import InvalidId, ObjectId
from explainaboard_web.impl.db import get_db
@@ -88,7 +90,10 @@ def insert_many(

@staticmethod
def find_one_by_id(
collection: DBCollection, docid: str, projection: Optional[dict] = None
collection: DBCollection,
docid: str,
projection: dict | None = None,
session: ClientSession | None = None,
):
"""
Find and return a document with the _id field
@@ -102,11 +107,16 @@ def find_one_by_id(
"""Mongo accepts custom ID"""
_id = docid
finally:
return DBUtils.get_collection(collection).find_one({"_id": _id}, projection)
return DBUtils.get_collection(collection).find_one(
{"_id": _id}, projection, session=session
)
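
The new `session` parameter lets a read participate in a multi-document transaction. A minimal usage sketch, assuming a pymongo `MongoClient` (the client handle and document id are hypothetical; the app normally obtains its client via `get_db()`):

```python
from pymongo import MongoClient

client = MongoClient()  # hypothetical; normally managed by get_db()

def read_system_in_txn(docid: str):
    # start_session/start_transaction are standard pymongo APIs; passing
    # the session into find_one_by_id makes this read transactional.
    with client.start_session() as session:
        with session.start_transaction():
            return DBUtils.find_one_by_id(
                DBUtils.DEV_SYSTEM_METADATA, docid, session=session
            )
```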

@staticmethod
def update_one_by_id(
collection: DBCollection, docid: str, field_to_value: dict
collection: DBCollection,
docid: str,
field_to_value: dict,
session: ClientSession | None = None,
) -> bool:
"""
Update a document with the _id field
@@ -117,7 +127,7 @@
"""
try:
result: UpdateResult = DBUtils.get_collection(collection).update_one(
{"_id": ObjectId(docid)}, {"$set": field_to_value}
{"_id": ObjectId(docid)}, {"$set": field_to_value}, session=session
)
if int(result.modified_count) == 1:
return True
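
Like the read above, `update_one_by_id` now threads a session through, and it signals success via `modified_count`. A hedged sketch of checking that contract inside a transaction callback (the field value and error handling are placeholders):

```python
from datetime import datetime

from pymongo.client_session import ClientSession

def touch_system(system_id: str, session: ClientSession) -> None:
    # A $set update; update_one_by_id returns False when modified_count != 1
    # (document missing, or the stored value was already identical).
    ok = DBUtils.update_one_by_id(
        DBUtils.DEV_SYSTEM_METADATA,
        system_id,
        {"last_modified": datetime.utcnow()},
        session=session,
    )
    if not ok:
        raise RuntimeError(f"system {system_id} was not updated")
```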
@@ -179,11 +189,11 @@ def count(collection: DBCollection, filt: dict = None) -> int:
@staticmethod
def find(
collection: DBCollection,
filt: Optional[dict] = None,
sort: Optional[list] = None,
filt: dict | None = None,
sort: list | None = None,
skip=0,
limit: int = 10,
projection: Optional[dict] = None,
projection: dict | None = None,
) -> tuple[Cursor, int]:
"""
Find multiple documents
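Only the annotations of `find` changed here (`X | None` unions replacing `Optional[X]`); the call shape is the same. A sketch of a paginated, sorted query — the filter value is hypothetical:

```python
# Second page (items 21-30) of one creator's documents, newest first.
cursor, total = DBUtils.find(
    DBUtils.DEV_SYSTEM_METADATA,
    filt={"creator": "user-123"},  # hypothetical filter
    sort=[("created_at", -1)],     # pymongo (key, direction) pairs
    skip=20,
    limit=10,
)
print(f"{total} matches; this page has {len(list(cursor))}")
```
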
171 changes: 19 additions & 152 deletions backend/src/impl/db_utils/system_db_utils.py
@@ -8,17 +8,15 @@
from typing import Any, NamedTuple

from bson import ObjectId
from explainaboard import DatalabLoaderOption, FileType, Source, get_processor
from explainaboard.loaders.file_loader import FileLoaderReturn
from explainaboard import DatalabLoaderOption, FileType, Source
from explainaboard.loaders.loader_registry import get_loader_class
from explainaboard.serialization.legacy import general_to_dict
from explainaboard_web.impl.auth import get_user
from explainaboard_web.impl.db_utils.dataset_db_utils import DatasetDBUtils
from explainaboard_web.impl.db_utils.db_utils import DBUtils
from explainaboard_web.impl.db_utils.user_db_utils import UserDBUtils
from explainaboard_web.impl.internal_models.system_model import SystemModel
from explainaboard_web.impl.storage import get_storage
from explainaboard_web.impl.utils import abort_with_error_message, binarize_bson
from explainaboard_web.impl.utils import abort_with_error_message
from explainaboard_web.models import (
AnalysisCase,
System,
@@ -33,7 +31,6 @@
class SystemDBUtils:

_COLON_RE = r"^([A-Za-z0-9_-]+): (.+)$"
_SYSTEM_OUTPUT_CONST = "__SYSOUT__"

@staticmethod
def _parse_colon_line(line) -> tuple[str, str]:
@@ -178,17 +175,15 @@ def find_systems(

@staticmethod
def _load_sys_output(
system: System,
metadata: SystemMetadata,
system: SystemModel,
system_output: SystemOutputProps,
custom_dataset: SystemOutputProps | None,
dataset_custom_features: dict,
):
"""
Load the system output from the uploaded file
"""
if custom_dataset:
return get_loader_class(task=metadata.task)(
return get_loader_class(task=system.task)(
dataset_data=custom_dataset.data,
output_data=system_output.data,
dataset_source=Source.in_memory,
@@ -198,13 +193,13 @@ def _load_sys_output(
).load()
elif system.dataset:
return (
get_loader_class(task=metadata.task)
get_loader_class(task=system.task)
.from_datalab(
dataset=DatalabLoaderOption(
system.dataset.dataset_name,
system.dataset.sub_dataset,
metadata.dataset_split,
custom_features=dataset_custom_features,
system.dataset.split,
custom_features=system.get_dataset_custom_features(),
),
output_data=system_output.data,
output_file_type=FileType(system_output.file_type),
@@ -214,41 +209,6 @@
)
raise ValueError("neither dataset or custom_dataset is available")

@staticmethod
def _process(
system: System,
metadata: SystemMetadata,
system_output_data: FileLoaderReturn,
custom_features: dict,
custom_analyses: list,
):
processor = get_processor(metadata.task)
metrics_lookup = {
metric.name: metric
for metric in get_processor(metadata.task).full_metric_list()
}
metric_configs = []
for metric_name in metadata.metric_names:
if metric_name not in metrics_lookup:
abort_with_error_message(
400, f"{metric_name} is not a supported metric"
)
metric_configs.append(metrics_lookup[metric_name])
processor_metadata = {
**metadata.to_dict(),
"dataset_name": system.dataset.dataset_name if system.dataset else None,
"sub_dataset_name": system.dataset.sub_dataset if system.dataset else None,
"dataset_split": metadata.dataset_split,
"task_name": metadata.task,
"metric_configs": metric_configs,
"custom_features": custom_features,
"custom_analyses": custom_analyses,
}

return processor.get_overall_statistics(
metadata=processor_metadata, sys_output=system_output_data.samples
)

@staticmethod
def _find_output_or_case_raw(
system_id: str, analysis_level: str, output_ids: list[int] | None
@@ -290,6 +250,7 @@ def _validate_and_create_system():
user = get_user()
system["creator"] = user.id
system["preferred_username"] = user.preferred_username
system["created_at"] = system["last_modified"] = datetime.utcnow()

if metadata.dataset_metadata_id:
if not metadata.dataset_split:
@@ -336,119 +297,25 @@ def _validate_and_create_system():
system = _validate_and_create_system()

try:
# -- find the dataset and grab custom features if they exist
dataset_custom_features = {}
if system.dataset:
dataset_info = DatasetDBUtils.find_dataset_by_id(
system.dataset.dataset_id
)
if dataset_info and dataset_info.custom_features:
dataset_custom_features = dict(dataset_info.custom_features)

# -- load the system output into memory from the uploaded file(s)
system_output_data = SystemDBUtils._load_sys_output(
system, metadata, system_output, custom_dataset, dataset_custom_features
)
system_custom_features: dict = (
system_output_data.metadata.custom_features or {}
system, system_output, custom_dataset
)
custom_analyses: list = system_output_data.metadata.custom_analyses or []

# -- combine custom features from the two sources
custom_features = dict(system_custom_features)
custom_features.update(dataset_custom_features)

# -- do the actual analysis and binarize the metric stats
overall_statistics = SystemDBUtils._process(
system, metadata, system_output_data, custom_features, custom_analyses
)
binarized_metric_stats = [
[binarize_bson(y.get_data()) for y in x]
for x in overall_statistics.metric_stats
]

# -- add the analysis results to the system object
sys_info = overall_statistics.sys_info

if sys_info.analysis_levels and sys_info.results.overall:
# store overall metrics in the DB so they can be queried
for level, result in zip(
sys_info.analysis_levels, sys_info.results.overall
):
system.results[level.name] = {}
for metric_result in result:
system.results[level.name][
metric_result.metric_name
] = metric_result.value

def db_operations(session: ClientSession) -> str:
# Insert system
system.created_at = system.last_modified = datetime.utcnow()
document = general_to_dict(system)
document.pop("system_id")
document.pop("preferred_username")
document["system_info"] = sys_info.to_dict()
document["metric_stats"] = binarized_metric_stats
system_id = DBUtils.insert_one(
DBUtils.DEV_SYSTEM_METADATA, document, session=session
)
# Compress the system output and upload to Cloud Storage
insert_list = []
sample_list = [general_to_dict(v) for v in system_output_data.samples]

blob_name = f"{system_id}/{SystemDBUtils._SYSTEM_OUTPUT_CONST}"
get_storage().compress_and_upload(
blob_name,
json.dumps(sample_list),
)
insert_list.append(
{
"system_id": system_id,
"analysis_level": SystemDBUtils._SYSTEM_OUTPUT_CONST,
"data": blob_name,
}
)
# Compress analysis cases
for i, (analysis_level, analysis_cases) in enumerate(
zip(
overall_statistics.sys_info.analysis_levels,
overall_statistics.analysis_cases,
)
):
case_list = [general_to_dict(v) for v in analysis_cases]

blob_name = f"{system_id}/{analysis_level.name}"
get_storage().compress_and_upload(blob_name, json.dumps(case_list))
insert_list.append(
{
"system_id": system_id,
"analysis_level": analysis_level.name,
"data": blob_name,
}
)
# Insert system output and analysis cases
output_collection = DBUtils.get_system_output_collection(system_id)
result = DBUtils.insert_many(
output_collection, insert_list, False, session
)
if not result:
abort_with_error_message(
400, f"failed to insert outputs for {system_id}"
)
return system_id

# -- perform upload to the DB
system_id = DBUtils.execute_transaction(db_operations)
system.system_id = system_id

# -- return the system
return system
except ValueError as e:
traceback.print_exc()
abort_with_error_message(400, str(e))
# mypy doesn't understand the NoReturn type in an except block,
# so this no-op re-raise is needed to satisfy it.
raise e
else:

def db_operations(session: ClientSession) -> None:
system.save_to_db(session)
system.save_system_output(system_output_data, session)
system.update_overall_statistics(metadata, system_output_data, session)

DBUtils.execute_transaction(db_operations)
return system
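
The whole write path now funnels through one transaction, so a failure in any step rolls back the metadata insert, the output records, and the statistics together. A sketch of what `DBUtils.execute_transaction` presumably wraps — `with_transaction` is the real pymongo API, while the client wiring here is an assumption:

```python
from typing import Callable, TypeVar

from pymongo import MongoClient
from pymongo.client_session import ClientSession

T = TypeVar("T")
client = MongoClient()  # hypothetical; the app gets its client via get_db()

def execute_transaction(func: Callable[[ClientSession], T]) -> T:
    # with_transaction commits on success, aborts on exception, and
    # retries the callback on transient transaction errors; func receives
    # the session to pass into each DBUtils call.
    with client.start_session() as session:
        return session.with_transaction(func)
```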

@staticmethod
def update_system_by_id(system_id: str, metadata: SystemMetadataUpdatable) -> bool:
@@ -506,7 +373,7 @@ def find_system_outputs(
find multiple system outputs whose ids are in output_ids
"""
sys_data = SystemDBUtils._find_output_or_case_raw(
str(system_id), SystemDBUtils._SYSTEM_OUTPUT_CONST, output_ids
str(system_id), SystemModel._SYSTEM_OUTPUT_CONST, output_ids
)
return [SystemDBUtils.system_output_from_dict(doc) for doc in sys_data]

[diff for the third changed file did not load]
