Merge/model db upload val #205

Merged: 67 commits, Nov 6, 2023
bd436f3
Add db setup doc; remove useless db files
PaulCCCCCCH Jul 25, 2023
c2e50be
add database schema change note
PaulCCCCCCH Jul 27, 2023
9b5e2ea
add db schema design diagram and code
PaulCCCCCCH Jul 27, 2023
63c2bb1
Create 'model.py' for initial model APIs sketch
Leon-Leyang Jul 30, 2023
c380573
Add docstrings for the UploadModel API
Leon-Leyang Jul 30, 2023
1b4edfe
Add 'model_utils.py' for utilities & outline 'UploadModel' API
Leon-Leyang Jul 30, 2023
00018f2
Get all metadata of the model according to the db design
Leon-Leyang Aug 7, 2023
e7d012e
Complete `init_model` function implementation
Leon-Leyang Aug 7, 2023
87212fc
Extend `UploadModel` to handle uploading models after training
Leon-Leyang Aug 8, 2023
3a6f693
refactor: implemented orm models
PaulCCCCCCH Aug 14, 2023
055a55c
fixed db model definition bugs
PaulCCCCCCH Aug 14, 2023
3a9f04b
change db types in function signatures
PaulCCCCCCH Aug 20, 2023
41967ca
Enhance UploadModel API with ID assignment and architecture metadata
Leon-Leyang Aug 27, 2023
e065540
Complete `val_model` function implementation
Leon-Leyang Aug 28, 2023
6a54e9b
fix: new database operations with SQLAlchemy partially working
PaulCCCCCCH Aug 28, 2023
bb88202
most functionalities working; not passing test cases
PaulCCCCCCH Aug 31, 2023
01571c5
test cases passing
PaulCCCCCCH Aug 31, 2023
fb906b3
chore: update docker image version; include SQLAlchemy dependency
PaulCCCCCCH Aug 31, 2023
56bc6a8
feat: implemented model operations
PaulCCCCCCH Aug 31, 2023
a530304
Update save path for model definition and checkpoint files to /genera…
Leon-Leyang Sep 3, 2023
8a60f0c
Refactor 'UploadModel' API for enhanced model handling
Leon-Leyang Sep 3, 2023
a423336
Migrate from importlib_resources to importlib.resources
Leon-Leyang Sep 3, 2023
8e9fe4b
Refresh requirements.txt with updated package versions
Leon-Leyang Sep 3, 2023
0000070
Merge branch 'dev' of https://github.com/HaohanWang/Robustar_implemen…
Leon-Leyang Sep 9, 2023
a0fb0b7
Merge branch 'chonghan/model-db' into 'merge/model-db-upload-val'
Leon-Leyang Sep 9, 2023
5ccd396
Merge branch 'leyang/model-upload-val' into merge/model-db-upload-val
Leon-Leyang Sep 9, 2023
450a07e
Correct some mistakes to let 'UploadModel' API work for basic custom …
Leon-Leyang Sep 10, 2023
f4c8c27
Fix type errors of var `pretrained` and `num_classes`
Leon-Leyang Sep 10, 2023
cfd78fa
Save the weight of pretrained predefined model to a local file
Leon-Leyang Sep 10, 2023
f8ea4c3
Raise an exception if pretrained weight is required for "ResNet18-32x32"
Leon-Leyang Sep 10, 2023
30a4d8d
Save `num_classes` of predefined models into local files for later re…
Leon-Leyang Sep 10, 2023
4088512
Save current model weights to a local file if the weight file is not …
Leon-Leyang Sep 11, 2023
687fe9f
Comment db clean up temporarily.
Leon-Leyang Sep 17, 2023
8964492
Permit optional transmission of description and tags in metadata from…
Leon-Leyang Sep 22, 2023
a240a6c
Update model metadata design: Introduce `nickname` and Rename `name` …
Leon-Leyang Sep 22, 2023
7684bde
Extract 'tags' column from 'Models' table into a new table, and maint…
Leon-Leyang Sep 22, 2023
68d16cd
Enhance `UploadModel` docstring documentation
Leon-Leyang Sep 22, 2023
f3f04a8
Close DB connection explicitly for .db file removal; Set fixture scop…
Leon-Leyang Oct 1, 2023
e8429f2
Change the base directory in pytest for debugging
Leon-Leyang Oct 1, 2023
e3a08a1
Debugging
Leon-Leyang Oct 1, 2023
eab8f7a
Debugging
Leon-Leyang Oct 1, 2023
bbbd56c
Debugging
Leon-Leyang Oct 1, 2023
552b620
Debugging
Leon-Leyang Oct 1, 2023
3cd7281
Debugging
Leon-Leyang Oct 1, 2023
12de854
Debugging
Leon-Leyang Oct 1, 2023
9189ca3
Restore for a full test
Leon-Leyang Oct 1, 2023
a870b7a
Restore for full test
Leon-Leyang Oct 1, 2023
0bd0339
Capture the 'stdout' and 'stderr' in pytest for debugging
Leon-Leyang Oct 1, 2023
8d31540
Debugging
Leon-Leyang Oct 1, 2023
d2c9f2e
Debugging
Leon-Leyang Oct 1, 2023
7fff551
Debugging
Leon-Leyang Oct 1, 2023
d21da8f
Solve the problem by specifying the db file address as absolute path
Leon-Leyang Oct 1, 2023
47a9d4b
Restore the config of CircleCI test
Leon-Leyang Oct 1, 2023
6b9b029
fix: provide app context to all thread creations
PaulCCCCCCH Oct 2, 2023
1af97cf
Merge branch 'merge/model-db-upload-val' of github.com:HaohanWang/Rob…
PaulCCCCCCH Oct 2, 2023
3f41697
Merge branch 'dev' into merge/model-db-upload-val
PaulCCCCCCH Oct 10, 2023
92e0ffa
fix: correctly getting context for child threads
PaulCCCCCCH Oct 10, 2023
9061a12
Resolve code review feedback:
Leon-Leyang Oct 15, 2023
367efbb
Merge remote-tracking branch 'origin/merge/model-db-upload-val' into …
Leon-Leyang Oct 15, 2023
2b9986b
Add `predefined` to the model's metadata
Leon-Leyang Oct 29, 2023
4dff6d2
Resolve conflict
Leon-Leyang Oct 31, 2023
e5174dc
Cherry-pick model refactors from refactor/model-db-upload-val
Leon-Leyang Nov 1, 2023
9d28e6c
Add one TODO
Leon-Leyang Nov 3, 2023
d958b86
Remove one TODO
Leon-Leyang Nov 3, 2023
4dcd1aa
Merge branch 'merge/v0.3' into 'merge/model-db-upload-val'
Leon-Leyang Nov 5, 2023
7dc199e
Add `pretrained` as a new field in 'models' table & Stop saving weigh…
Leon-Leyang Nov 6, 2023
e028498
Set `pretrained` as required field for metadata
Leon-Leyang Nov 6, 2023
250 changes: 107 additions & 143 deletions back-end/apis/model.py
@@ -1,14 +1,7 @@
import os
import json
import string
import uuid
import torch
import io
import contextlib
import traceback
from flask import request
from flask import Blueprint
from datetime import datetime
from utils.model_utils import *
from objects.RResponse import RResponse
from objects.RServer import RServer
@@ -96,6 +89,12 @@ def UploadModel():
type: "string"
description: "A nickname for the model."
required: true
predefined:
type: "string"
description: |
Indicates if a predefined model is being used.
"1" represents predefined, "0" otherwise.
required: true
description:
type: "string"
description: "A description of the model (optional)."
@@ -107,11 +106,11 @@ def UploadModel():
pretrained:
type: "string"
description: |
Indicates if a predefined model is being used.
"1" represents pretrained, "0" otherwise (required only if users choose to use a predefined model).
Indicates whether the model is pretrained.
Should only be set to "1" if the model is predefined and pretrained.
num_classes:
type: "integer"
description: "The number of classes (required only if users choose to use a predefined model)."
type: "string"
description: "The number of classes for the predefined model (will soon be removed)."

responses:
200:
@@ -139,160 +138,125 @@ def UploadModel():
example: "Success"

"""
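The docstring above specifies `predefined` and `pretrained` as the strings "1" or "0", and the handler below converts them with `bool(int(...))`. A minimal sketch of that conversion pattern — the `parse_flag` helper and its error message are illustrative, not part of the PR:

```python
# Sketch of turning the "0"/"1" string flags from the form metadata into
# booleans, mirroring the `bool(int(...))` pattern used in the diff.
# `parse_flag` is a hypothetical helper, not part of the PR.

def parse_flag(metadata: dict, key: str) -> bool:
    """Convert a required "0"/"1" string field to a bool, failing loudly otherwise."""
    value = metadata.get(key)
    if value not in ("0", "1"):
        raise ValueError(f"Field '{key}' must be the string '0' or '1', got {value!r}")
    return bool(int(value))

metadata = {"predefined": "1", "pretrained": "0"}
print(parse_flag(metadata, "predefined"))  # True
print(parse_flag(metadata, "pretrained"))  # False
```

Validating up front like this avoids the `ValueError` that a bare `int(metadata.get(...))` would raise on a missing or malformed field deep inside the handler.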
# Get the model's metadata
metadata = request.form
if not metadata:
RResponse.abort(400, "Empty request received")

# Check if the folder for saving models exists, if not, create it
base_dir = RServer.get_server().base_dir
models_dir = os.path.join(base_dir, "generated", "models")
if not os.path.exists(models_dir):
os.makedirs(models_dir)

print("Requested to upload a new model")
code_path = None
weight_path = None
try:
# Precheck the request
errors = precheck_request_4_upload_model(request)
if len(errors) > 0:
error_message = "; ".join(errors)
return RResponse.fail(f"Request validation failed: {error_message}", 400)

# Generate a uuid for the model saving
saving_id = str(uuid.uuid4())
metadata_str = request.form.get("metadata")
metadata = json.loads(metadata_str)
Owner: Curious: is there a reason why we want to wrap all parameters in metadata field?

Collaborator Author: Emm... I just did it because I thought it made the data look more organized. Flattening it shall also work.


code_path = os.path.join(base_dir, "generated", "models", f"{saving_id}.py")
# Create the code and ckpt directories if they don't exist
create_models_dir()

metadata_4_save = {
"class_name": None,
"nickname": None,
"description": None,
"architecture": None,
"tags": None,
"create_time": None,
"weight_path": None,
"code_path": None,
"epoch": None,
"train_accuracy": None,
"val_accuracy": None,
"test_accuracy": None,
"last_eval_on_dev_set": None,
"last_eval_on_test_set": None,
}
print("Requested to upload a new model")

# Get the model's class name
class_name = metadata.get("class_name")
# Generate a uuid for the model saving
saving_id = str(uuid.uuid4())

# If the model is custom(i.e. it has code definition)
if "code" in request.form:
# Get the model definition code and save it to a temporary file
code = request.form.get("code")
try:
with open(code_path, "w") as code_file:
code_file.write(code)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(500, f"Failed to save the model definition. {e}")
code_path = os.path.join(
RServer.get_server().base_dir,
"generated",
"models",
"code",
f"{saving_id}.py",
)

# Initialize the custom model
try:
model = RModelWrapper.init_custom_model(
code_path, class_name, RServer.get_model_wrapper().device
)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(400, f"Failed to initialize the custom model. {e}")
elif "pretrained" in metadata: # If the model is predefined
pretrained = bool(int(metadata.get("pretrained")))
num_classes = int(metadata.get("num_classes"))
try:
model = init_predefined_model(class_name, pretrained, num_classes)
with open(code_path, "w") as code_file:
code_file.write(f"num_classes = {num_classes}")
except Exception as e:
traceback.print_exc()
return RResponse.abort(
400, f"Failed to initialize the predefined model. {e}"
# Get the model's class name
class_name = metadata.get("class_name")

predefined = bool(int(metadata.get("predefined")))

# Save the model's code definition and initialize the model
if not predefined: # If the model is custom
# Get the model definition code and save it to a temporary file
code = request.form.get("code")
save_code(code, code_path)
# Initialize the model
try:
model = RModelWrapper.init_custom_model(
code_path, class_name, RServer.get_model_wrapper().device
)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(code_path, weight_path)
return RResponse.fail(
f"Failed to initialize the custom model. {e}", 400
)
elif predefined: # If the model is predefined
pretrained = bool(int(metadata.get("pretrained")))
num_classes = int(metadata.get("num_classes"))
# TODO: Stop relying on the user to provide the number of classes
code = f"num_classes = {num_classes}"
save_code(code, code_path)
Owner: I think for pre-defined models we don't need to save any code right? Looks like you are trying to save some dummy string which is never used? Shall we remove?

Collaborator Author (Leon-Leyang, Nov 6, 2023): Yes, I think after initializing this num_classes in RDataManager, we can remove this. But currently load_model_by_name still relies on this (I changed the code in load_model_by_name to read this variable from the code file instead of using self.num_classes for a correct loading of the predefined model to pass my manual test).

Collaborator Author: I will set this num_classes in RDataManager after this merge. I have created issue #207 for this.

Owner: Got it. Thanks @Leon-Leyang
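As the thread above discusses, `num_classes` is persisted for predefined models by writing a one-line code file (`code_file.write(f"num_classes = {num_classes}")`) and read back later by `load_model_by_name`. A self-contained sketch of that round-trip, assuming the file contains only that single assignment; the `ast`-based reader is our assumption, since the PR does not show how the loader parses the file:

```python
# Round-trip sketch: persist `num_classes` to a per-model code file (as the PR
# does) and parse it back. The ast-based reader is illustrative — the actual
# load_model_by_name implementation is not shown in this diff.
import ast
import os
import tempfile

def save_num_classes(code_path: str, num_classes: int) -> None:
    # Matches the diff: the code file for a predefined model holds one assignment.
    with open(code_path, "w") as code_file:
        code_file.write(f"num_classes = {num_classes}")

def read_num_classes(code_path: str) -> int:
    # Parse the file instead of exec()-ing it, so arbitrary code never runs.
    with open(code_path) as code_file:
        tree = ast.parse(code_file.read())
    for node in tree.body:
        if isinstance(node, ast.Assign) and node.targets[0].id == "num_classes":
            return node.value.value
    raise ValueError("num_classes not found in code file")

path = os.path.join(tempfile.mkdtemp(), "model.py")
save_num_classes(path, 10)
print(read_num_classes(path))  # 10
```

Once `num_classes` lives in RDataManager (issue #207), both helpers become unnecessary, which is the clean-up the thread agrees on.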

try:
model = RModelWrapper.init_predefined_model(
Owner: When we load a model, we will always check the model weight path:

  • If the weight path is empty, randomly initialize
  • If the weight path exists, load it

If a model is pre-defined and pre-trained, when we load it, we also want to load the weights. The question is, how do we know whether a model is pre-trained when pulling it out of the DB? There is no pre-trained field in the metadata.

Two ways:

  • For pre-defined pre-trained models, we might also need to save model weights. This adds logical consistency, but sacrifices storage (model weight stored twice, once in pip cache, once in robustar).
  • Add another field pre-trained in metadata, and add an additional check for this field when loading. I personally prefer this. This is also good information to track from the user's perspective.

Collaborator Author: Ok, I will add pretrained as a new field.
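The option the thread settles on — checking a `pretrained` metadata field at load time — can be sketched as a small decision function. The function name and its return strings are stand-ins of ours, not the PR's API:

```python
# Sketch of option 2 from the thread above: decide how to initialize weights
# when pulling a model out of the DB, given the new `pretrained` field.
# `decide_weight_source` and its return values are illustrative, not the PR's API.
from typing import Optional

def decide_weight_source(weight_path: Optional[str],
                         predefined: bool,
                         pretrained: bool) -> str:
    if weight_path:                  # a checkpoint was saved for this model
        return "load_checkpoint"
    if predefined and pretrained:    # fetch the library's pretrained weights
        return "load_pretrained"     # instead of storing them twice
    return "random_init"             # no weights anywhere: random initialization

print(decide_weight_source("/generated/models/ckpt/abc.pth", True, True))  # load_checkpoint
print(decide_weight_source(None, True, True))   # load_pretrained
print(decide_weight_source(None, True, False))  # random_init
```

This keeps the storage savings of option 2: predefined pretrained weights stay only in the pip cache, and the DB merely records that they should be fetched.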

class_name,
pretrained,
num_classes,
RServer.get_model_wrapper().device,
)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(code_path, weight_path)
return RResponse.fail(
f"Failed to initialize the predefined model. {e}", 400
)
else:
return RResponse.fail(
"Invalid request. The model is neither custom nor predefined.", 400
)
else:
return RResponse.abort(400, "The model is neither custom nor predefined.")

# Get the weight file and save it to a temporary location if it exists
if "weight_file" in request.files:
weight_file = request.files.get("weight_file")
try:
# Get the weight file and save it to a temporary location if it exists
if "weight_file" in request.files:
weight_path = os.path.join(
RServer.get_server().base_dir, "generated", "models", f"{saving_id}.pth"
RServer.get_server().base_dir,
"generated",
"models",
"ckpt",
f"{saving_id}.pth",
)
weight_file.save(weight_path)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(500, f"Failed to save the weight file. {e}")

# Load the weights from the file
try:
model.load_state_dict(torch.load(weight_path))
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(400, f"Failed to load the weights. {e}")
else: # If the weight file is not provided, save the current weights to a temporary location
weight_file = request.files.get("weight_file")
save_ckpt_weight(weight_file, weight_path)
# Load and validate the weights from the file
try:
load_ckpt_weight(model, weight_path)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(code_path, weight_path)
return RResponse.fail(f"Failed to load the weights. {e}", 400)

# Validate the model
try:
weight_path = os.path.join(
RServer.get_server().base_dir, "generated", "models", f"{saving_id}.pth"
dummy_model_wrapper = DummyModelWrapper(
model, RServer.get_model_wrapper().device
)
torch.save(model.state_dict(), weight_path)
val_model(dummy_model_wrapper)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(500, f"Failed to save the weight file. {e}")
clear_model_temp_files(code_path, weight_path)
return RResponse.fail(f"The model is invalid. {e}", 400)

# Validate the model
try:
dummy_model_wrapper = DummyModelWrapper(
model, RServer.get_model_wrapper().device
# Construct the metadata for saving
metadata_4_save = construct_metadata_4_save(
metadata, code_path, weight_path, model
)
val_model(RServer.get_data_manager(), dummy_model_wrapper)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(400, f"The model is invalid. {e}")

# Update the metadata for saving
metadata_4_save["class_name"] = class_name
metadata_4_save["nickname"] = metadata.get("nickname")
metadata_4_save["description"] = (
metadata.get("description") if metadata.get("description") else None
)
metadata_4_save["tags"] = metadata.get("tags") if metadata.get("tags") else None
metadata_4_save["create_time"] = datetime.now()
metadata_4_save["code_path"] = code_path
metadata_4_save["weight_path"] = weight_path
metadata_4_save["epoch"] = 0
metadata_4_save["train_accuracy"] = None
metadata_4_save["val_accuracy"] = None
metadata_4_save["test_accuracy"] = None
metadata_4_save["last_eval_on_dev_set"] = None
metadata_4_save["last_eval_on_test_set"] = None

# Save the model's architecture to the metadata
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
print(model)
metadata_4_save["architecture"] = buffer.getvalue()

# Save the model's metadata to the database
try:
# Save the model's metadata to the database
RServer.get_model_wrapper().create_model(metadata_4_save)
except Exception as e:
traceback.print_exc()
return RResponse.abort(500, f"Failed to save the model. {e}")

# Set the current model to the newly uploaded model
try:
# Set the current model to the newly uploaded model
RServer.get_model_wrapper().set_current_model(metadata.get("nickname"))

return RResponse.ok("Success")
except Exception as e:
traceback.print_exc()
return RResponse.abort(500, f"Failed to set the current model. {e}")

return RResponse.ok("Success")
clear_model_temp_files(code_path, weight_path)
return RResponse.abort(500, f"Unexpected error. {e}")


@model_api.route("/model/list", methods=["GET"])
1 change: 1 addition & 0 deletions back-end/apis/predict.py
@@ -16,6 +16,7 @@

predict_api = Blueprint("predict_api", __name__)


# Return prediction result
@predict_api.route("/predict/<split>")
def predict(split):
2 changes: 2 additions & 0 deletions back-end/database/model.py
@@ -31,6 +31,8 @@ class Models(db.Model):
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
class_name = db.Column(db.String)
nickname = db.Column(db.String)
predefined = db.Column(db.Boolean)
pretrained = db.Column(db.Boolean)
description = db.Column(db.String)
architecture = db.Column(db.String)
tags = db.relationship("Tags", secondary=model_tag_rel, backref="models")
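The new `predefined` and `pretrained` Boolean columns above live in a Flask-SQLAlchemy model. A standalone sketch with plain SQLAlchemy against in-memory SQLite — the `Tags` table and the association table's column names are assumptions based on the diff and the "Extract 'tags' column" commit:

```python
# Hypothetical standalone version of the `Models` table from
# back-end/database/model.py, using plain SQLAlchemy instead of
# Flask-SQLAlchemy. Column names mirror the diff; the Tags schema
# and association-table columns are assumptions.
from sqlalchemy import (Boolean, Column, ForeignKey, Integer, String,
                        Table, create_engine)
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

# Many-to-many link, introduced when `tags` was extracted into its own table.
model_tag_rel = Table(
    "model_tag_rel",
    Base.metadata,
    Column("model_id", ForeignKey("models.id"), primary_key=True),
    Column("tag_id", ForeignKey("tags.id"), primary_key=True),
)

class Models(Base):
    __tablename__ = "models"
    id = Column(Integer, primary_key=True, autoincrement=True)
    class_name = Column(String)
    nickname = Column(String)
    predefined = Column(Boolean)   # new column from this PR
    pretrained = Column(Boolean)   # new column from this PR
    description = Column(String)
    architecture = Column(String)
    tags = relationship("Tags", secondary=model_tag_rel, backref="models")

class Tags(Base):
    __tablename__ = "tags"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Models(class_name="ResNet18", nickname="demo",
                       predefined=True, pretrained=False,
                       tags=[Tags(name="vision")]))
    session.commit()
    fetched = session.query(Models).filter_by(nickname="demo").one()
    result = (fetched.predefined, fetched.pretrained,
              [t.name for t in fetched.tags])
print(result)  # (True, False, ['vision'])
```

SQLAlchemy's `Boolean` type round-trips cleanly through SQLite's integer storage, so the flags come back as Python `True`/`False` rather than `1`/`0`.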
@@ -123,7 +123,6 @@ def denormalize(tensor):
def standardize_and_clip(
tensor, min_value=0.0, max_value=1.0, saturation=0.1, brightness=0.5
):

"""Standardizes and clips input tensor.

Standardizes the input tensor (mean = 0.0, std = 1.0). The color saturation