Merge/model db upload val #205
Changes from all commits
@@ -1,14 +1,7 @@
import os
import json
import string
import uuid
import torch
import io
import contextlib
import traceback
from flask import request
from flask import Blueprint
from datetime import datetime
from utils.model_utils import *
from objects.RResponse import RResponse
from objects.RServer import RServer

@@ -96,6 +89,12 @@ def UploadModel():
type: "string"
description: "A nickname for the model."
required: true
predefined:
type: "string"
description: |
Indicates if a predefined model is being used.
"1" represents predefined, "0" otherwise.
required: true
description:
type: "string"
description: "A description of the model (optional)."

@@ -107,11 +106,11 @@ def UploadModel():
pretrained:
type: "string"
description: |
Indicates if a predefined model is being used.
"1" represents pretrained, "0" otherwise (required only if users choose to use a predefined model).
Indicates whether the model is pretrained.
Should only be set to "1" if the model is predefined and pretrained.
num_classes:
type: "integer"
description: "The number of classes (required only if users choose to use a predefined model)."
type: "string"
description: "The number of classes for the predefined model (will soon be removed)."

responses:
200:

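Taken together with the new "predefined" field above, these values are all sent inside the single JSON "metadata" form field that the handler below parses with json.loads, plus an optional "weight_file" upload (and, for custom models, a "code" field). A rough, hypothetical sketch of such a client request; the URL is a placeholder, since the upload route itself is not shown in this diff, and the example values are not from this PR:

```python
# Hypothetical client sketch of the request shape implied by this docstring.
# The URL is a placeholder; field names follow the docstring and the handler
# below ("metadata" is a JSON string in one form field, "weight_file" is an
# optional file upload).
import json
import requests

metadata = {
    "class_name": "ResNet18",   # example value, not from this PR
    "nickname": "demo-resnet",  # example value, not from this PR
    "description": "demo upload",
    "predefined": "1",          # "1" = predefined, "0" = custom
    "pretrained": "1",          # only meaningful for predefined models
    "num_classes": "10",        # a string per the docstring (to be removed later)
}

with open("resnet18.pth", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/model",  # placeholder route
        data={"metadata": json.dumps(metadata)},
        files={"weight_file": f},  # optional; omit to keep the freshly initialized weights
    )
print(resp.status_code, resp.text)
```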
@@ -139,160 +138,125 @@ def UploadModel():
example: "Success"

"""
# Get the model's metadata
metadata = request.form
if not metadata:
RResponse.abort(400, "Empty request received")

# Check if the folder for saving models exists, if not, create it
base_dir = RServer.get_server().base_dir
models_dir = os.path.join(base_dir, "generated", "models")
if not os.path.exists(models_dir):
os.makedirs(models_dir)

print("Requested to upload a new model")
code_path = None
weight_path = None
try:
# Precheck the request
errors = precheck_request_4_upload_model(request)
if len(errors) > 0:
error_message = "; ".join(errors)
return RResponse.fail(f"Request validation failed: {error_message}", 400)

# Generate a uuid for the model saving
saving_id = str(uuid.uuid4())
metadata_str = request.form.get("metadata")
metadata = json.loads(metadata_str)

code_path = os.path.join(base_dir, "generated", "models", f"{saving_id}.py")
# Create the code and ckpt directories if they don't exist
create_models_dir()

metadata_4_save = {
"class_name": None,
"nickname": None,
"description": None,
"architecture": None,
"tags": None,
"create_time": None,
"weight_path": None,
"code_path": None,
"epoch": None,
"train_accuracy": None,
"val_accuracy": None,
"test_accuracy": None,
"last_eval_on_dev_set": None,
"last_eval_on_test_set": None,
}
print("Requested to upload a new model")

# Get the model's class name
class_name = metadata.get("class_name")
# Generate a uuid for the model saving
saving_id = str(uuid.uuid4())

# If the model is custom(i.e. it has code definition)
if "code" in request.form:
# Get the model definition code and save it to a temporary file
code = request.form.get("code")
try:
with open(code_path, "w") as code_file:
code_file.write(code)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(500, f"Failed to save the model definition. {e}")
code_path = os.path.join(
RServer.get_server().base_dir,
"generated",
"models",
"code",
f"{saving_id}.py",
)

# Initialize the custom model
try:
model = RModelWrapper.init_custom_model(
code_path, class_name, RServer.get_model_wrapper().device
)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(400, f"Failed to initialize the custom model. {e}")
elif "pretrained" in metadata: # If the model is predefined
pretrained = bool(int(metadata.get("pretrained")))
num_classes = int(metadata.get("num_classes"))
try:
model = init_predefined_model(class_name, pretrained, num_classes)
with open(code_path, "w") as code_file:
code_file.write(f"num_classes = {num_classes}")
except Exception as e:
traceback.print_exc()
return RResponse.abort(
400, f"Failed to initialize the predefined model. {e}"
# Get the model's class name
class_name = metadata.get("class_name")

predefined = bool(int(metadata.get("predefined")))

# Save the model's code definition and initialize the model
if not predefined: # If the model is custom
# Get the model definition code and save it to a temporary file
code = request.form.get("code")
save_code(code, code_path)
# Initialize the model
try:
model = RModelWrapper.init_custom_model(
code_path, class_name, RServer.get_model_wrapper().device
)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(code_path, weight_path)
return RResponse.fail(
f"Failed to initialize the custom model. {e}", 400
)
elif predefined: # If the model is predefined
pretrained = bool(int(metadata.get("pretrained")))
num_classes = int(metadata.get("num_classes"))
# TODO: Stop relying on the user to provide the number of classes
code = f"num_classes = {num_classes}"
save_code(code, code_path)
Review comment: I think for pre-defined models we don't need to save any code, right? It looks like you are trying to save a dummy string that is never used; shall we remove it?
Reply: Yes, I think after initializing this […]
Reply: I will set this […]
Reply: Got it. Thanks @Leon-Leyang
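To make the suggestion concrete, here is a minimal sketch (not code from this PR) of how the predefined branch could skip the dummy file entirely and leave code_path unset, reusing the save_code(code, path) helper and the variables already defined above:

```python
# Sketch only: one way to act on the review suggestion, assuming the same
# helpers and variables as the surrounding handler.
if not predefined:
    # Custom model: the code definition is required, so persist it.
    code = request.form.get("code")
    save_code(code, code_path)
else:
    # Predefined model: there is no user code to save, so keep code_path
    # as None instead of writing a dummy "num_classes = ..." file.
    code_path = None
```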

try:
model = RModelWrapper.init_predefined_model(
Review comment: When we load a model, we will always check the model weight path and: […] If a model is pre-defined and pre-trained, when we load it, we also want to load the weights. The question is, how do we know whether a model is pre-trained when pulling it out of the DB? There is no […]. Two ways: […]
Reply: Ok, I will add […]
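One of the options being weighed here, always persisting a checkpoint at upload time (which the later part of this diff already does when no weight file is provided), would keep loading uniform. A rough sketch with a hypothetical helper name, not code from this PR:

```python
import torch

def load_weights_from_record(model, record, device):
    # Hypothetical loader sketch: if a checkpoint is always saved at upload
    # time, the loader never needs to know whether the model started out
    # pretrained; it just restores whatever was persisted at record["weight_path"].
    state_dict = torch.load(record["weight_path"], map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)
```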

class_name,
pretrained,
num_classes,
RServer.get_model_wrapper().device,
)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(code_path, weight_path)
return RResponse.fail(
f"Failed to initialize the predefined model. {e}", 400
)
else:
return RResponse.fail(
"Invalid request. The model is neither custom nor predefined.", 400
)
else:
return RResponse.abort(400, "The model is neither custom nor predefined.")

# Get the weight file and save it to a temporary location if it exists
if "weight_file" in request.files:
weight_file = request.files.get("weight_file")
try:
# Get the weight file and save it to a temporary location if it exists
if "weight_file" in request.files:
weight_path = os.path.join(
RServer.get_server().base_dir, "generated", "models", f"{saving_id}.pth"
RServer.get_server().base_dir,
"generated",
"models",
"ckpt",
f"{saving_id}.pth",
)
weight_file.save(weight_path)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(500, f"Failed to save the weight file. {e}")

# Load the weights from the file
try:
model.load_state_dict(torch.load(weight_path))
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(400, f"Failed to load the weights. {e}")
else: # If the weight file is not provided, save the current weights to a temporary location
weight_file = request.files.get("weight_file")
save_ckpt_weight(weight_file, weight_path)
# Load and validate the weights from the file
try:
load_ckpt_weight(model, weight_path)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(code_path, weight_path)
return RResponse.fail(f"Failed to load the weights. {e}", 400)

# Validate the model
try:
weight_path = os.path.join(
RServer.get_server().base_dir, "generated", "models", f"{saving_id}.pth"
dummy_model_wrapper = DummyModelWrapper(
model, RServer.get_model_wrapper().device
)
torch.save(model.state_dict(), weight_path)
val_model(dummy_model_wrapper)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(500, f"Failed to save the weight file. {e}")
clear_model_temp_files(code_path, weight_path)
return RResponse.fail(f"The model is invalid. {e}", 400)

# Validate the model
try:
dummy_model_wrapper = DummyModelWrapper(
model, RServer.get_model_wrapper().device
# Construct the metadata for saving
metadata_4_save = construct_metadata_4_save(
metadata, code_path, weight_path, model
)
val_model(RServer.get_data_manager(), dummy_model_wrapper)
except Exception as e:
traceback.print_exc()
clear_model_temp_files(base_dir, saving_id)
return RResponse.abort(400, f"The model is invalid. {e}")

# Update the metadata for saving
metadata_4_save["class_name"] = class_name
metadata_4_save["nickname"] = metadata.get("nickname")
metadata_4_save["description"] = (
metadata.get("description") if metadata.get("description") else None
)
metadata_4_save["tags"] = metadata.get("tags") if metadata.get("tags") else None
metadata_4_save["create_time"] = datetime.now()
metadata_4_save["code_path"] = code_path
metadata_4_save["weight_path"] = weight_path
metadata_4_save["epoch"] = 0
metadata_4_save["train_accuracy"] = None
metadata_4_save["val_accuracy"] = None
metadata_4_save["test_accuracy"] = None
metadata_4_save["last_eval_on_dev_set"] = None
metadata_4_save["last_eval_on_test_set"] = None

# Save the model's architecture to the metadata
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
print(model)
metadata_4_save["architecture"] = buffer.getvalue()

# Save the model's metadata to the database
try:
# Save the model's metadata to the database
RServer.get_model_wrapper().create_model(metadata_4_save)
except Exception as e:
traceback.print_exc()
return RResponse.abort(500, f"Failed to save the model. {e}")

# Set the current model to the newly uploaded model
try:
# Set the current model to the newly uploaded model
RServer.get_model_wrapper().set_current_model(metadata.get("nickname"))

return RResponse.ok("Success")
except Exception as e:
traceback.print_exc()
return RResponse.abort(500, f"Failed to set the current model. {e}")

return RResponse.ok("Success")
clear_model_temp_files(code_path, weight_path)
return RResponse.abort(500, f"Unexpected error. {e}")

@model_api.route("/model/list", methods=["GET"])

Review comment: Curious: is there a reason why we want to wrap all parameters in the "metadata" field?
Reply: Emm... I just did it because I thought it made the data look more organized. Flattening it would also work.
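For reference, the difference being discussed is only in how the handler reads the fields. A small hypothetical sketch of both shapes; the function names are illustrative and would have to run inside a Flask request context:

```python
import json
from flask import request

def read_params_wrapped():
    # Wrapped shape (what this PR does): one JSON string in a "metadata" form field.
    metadata = json.loads(request.form.get("metadata"))
    return metadata.get("class_name"), bool(int(metadata.get("predefined")))

def read_params_flattened():
    # Flattened shape: each parameter is its own form field, read directly
    # from request.form (roughly what the pre-PR code did).
    return (
        request.form.get("class_name"),
        bool(int(request.form.get("predefined", "0"))),
    )
```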