diff --git a/back-end/apis/model.py b/back-end/apis/model.py index 1e566f1e..d41b0eb2 100644 --- a/back-end/apis/model.py +++ b/back-end/apis/model.py @@ -1,14 +1,7 @@ -import os -import json -import string import uuid -import torch -import io -import contextlib import traceback from flask import request from flask import Blueprint -from datetime import datetime from utils.model_utils import * from objects.RResponse import RResponse from objects.RServer import RServer @@ -96,6 +89,12 @@ def UploadModel(): type: "string" description: "A nickname for the model." required: true + predefined: + type: "string" + description: | + Indicates if a predefined model is being used. + "1" represents predefined, "0" otherwise. + required: true description: type: "string" description: "A description of the model (optional)." @@ -107,11 +106,11 @@ def UploadModel(): pretrained: type: "string" description: | - Indicates if a predefined model is being used. - "1" represents pretrained, "0" otherwise (required only if users choose to use a predefined model). + Indicates whether the model is pretrained. + Should only be set to "1" if the model is predefined and pretrained. num_classes: - type: "integer" - description: "The number of classes (required only if users choose to use a predefined model)." + type: "string" + description: "The number of classes for the predefined model (will soon be removed)." responses: 200: @@ -139,160 +138,125 @@ def UploadModel(): example: "Success" """ - # Get the model's metadata - metadata = request.form - if not metadata: - RResponse.abort(400, "Empty request received") - - # Check if the folder for saving models exists, if not, create it - base_dir = RServer.get_server().base_dir - models_dir = os.path.join(base_dir, "generated", "models") - if not os.path.exists(models_dir): - os.makedirs(models_dir) - - print("Requested to upload a new model") + code_path = None + weight_path = None + try: + # Precheck the request + errors = precheck_request_4_upload_model(request) + if len(errors) > 0: + error_message = "; ".join(errors) + return RResponse.fail(f"Request validation failed: {error_message}", 400) - # Generate a uuid for the model saving - saving_id = str(uuid.uuid4()) + metadata_str = request.form.get("metadata") + metadata = json.loads(metadata_str) - code_path = os.path.join(base_dir, "generated", "models", f"{saving_id}.py") + # Create the code and ckpt directories if they don't exist + create_models_dir() - metadata_4_save = { - "class_name": None, - "nickname": None, - "description": None, - "architecture": None, - "tags": None, - "create_time": None, - "weight_path": None, - "code_path": None, - "epoch": None, - "train_accuracy": None, - "val_accuracy": None, - "test_accuracy": None, - "last_eval_on_dev_set": None, - "last_eval_on_test_set": None, - } + print("Requested to upload a new model") - # Get the model's class name - class_name = metadata.get("class_name") + # Generate a uuid for the model saving + saving_id = str(uuid.uuid4()) - # If the model is custom(i.e. it has code definition) - if "code" in request.form: - # Get the model definition code and save it to a temporary file - code = request.form.get("code") - try: - with open(code_path, "w") as code_file: - code_file.write(code) - except Exception as e: - traceback.print_exc() - clear_model_temp_files(base_dir, saving_id) - return RResponse.abort(500, f"Failed to save the model definition. {e}") + code_path = os.path.join( + RServer.get_server().base_dir, + "generated", + "models", + "code", + f"{saving_id}.py", + ) - # Initialize the custom model - try: - model = RModelWrapper.init_custom_model( - code_path, class_name, RServer.get_model_wrapper().device - ) - except Exception as e: - traceback.print_exc() - clear_model_temp_files(base_dir, saving_id) - return RResponse.abort(400, f"Failed to initialize the custom model. {e}") - elif "pretrained" in metadata: # If the model is predefined - pretrained = bool(int(metadata.get("pretrained"))) - num_classes = int(metadata.get("num_classes")) - try: - model = init_predefined_model(class_name, pretrained, num_classes) - with open(code_path, "w") as code_file: - code_file.write(f"num_classes = {num_classes}") - except Exception as e: - traceback.print_exc() - return RResponse.abort( - 400, f"Failed to initialize the predefined model. {e}" + # Get the model's class name + class_name = metadata.get("class_name") + + predefined = bool(int(metadata.get("predefined"))) + + # Save the model's code definition and initialize the model + if not predefined: # If the model is custom + # Get the model definition code and save it to a temporary file + code = request.form.get("code") + save_code(code, code_path) + # Initialize the model + try: + model = RModelWrapper.init_custom_model( + code_path, class_name, RServer.get_model_wrapper().device + ) + except Exception as e: + traceback.print_exc() + clear_model_temp_files(code_path, weight_path) + return RResponse.fail( + f"Failed to initialize the custom model. {e}", 400 + ) + elif predefined: # If the model is predefined + pretrained = bool(int(metadata.get("pretrained"))) + num_classes = int(metadata.get("num_classes")) + # TODO: Stop relying on the user to provide the number of classes + code = f"num_classes = {num_classes}" + save_code(code, code_path) + try: + model = RModelWrapper.init_predefined_model( + class_name, + pretrained, + num_classes, + RServer.get_model_wrapper().device, + ) + except Exception as e: + traceback.print_exc() + clear_model_temp_files(code_path, weight_path) + return RResponse.fail( + f"Failed to initialize the predefined model. {e}", 400 + ) + else: + return RResponse.fail( + "Invalid request. The model is neither custom nor predefined.", 400 ) - else: - return RResponse.abort(400, "The model is neither custom nor predefined.") - # Get the weight file and save it to a temporary location if it exists - if "weight_file" in request.files: - weight_file = request.files.get("weight_file") - try: + # Get the weight file and save it to a temporary location if it exists + if "weight_file" in request.files: weight_path = os.path.join( - RServer.get_server().base_dir, "generated", "models", f"{saving_id}.pth" + RServer.get_server().base_dir, + "generated", + "models", + "ckpt", + f"{saving_id}.pth", ) - weight_file.save(weight_path) - except Exception as e: - traceback.print_exc() - clear_model_temp_files(base_dir, saving_id) - return RResponse.abort(500, f"Failed to save the weight file. {e}") - - # Load the weights from the file - try: - model.load_state_dict(torch.load(weight_path)) - except Exception as e: - traceback.print_exc() - clear_model_temp_files(base_dir, saving_id) - return RResponse.abort(400, f"Failed to load the weights. {e}") - else: # If the weight file is not provided, save the current weights to a temporary location + weight_file = request.files.get("weight_file") + save_ckpt_weight(weight_file, weight_path) + # Load and validate the weights from the file + try: + load_ckpt_weight(model, weight_path) + except Exception as e: + traceback.print_exc() + clear_model_temp_files(code_path, weight_path) + return RResponse.fail(f"Failed to load the weights. {e}", 400) + + # Validate the model try: - weight_path = os.path.join( - RServer.get_server().base_dir, "generated", "models", f"{saving_id}.pth" + dummy_model_wrapper = DummyModelWrapper( + model, RServer.get_model_wrapper().device ) - torch.save(model.state_dict(), weight_path) + val_model(dummy_model_wrapper) except Exception as e: traceback.print_exc() - clear_model_temp_files(base_dir, saving_id) - return RResponse.abort(500, f"Failed to save the weight file. {e}") + clear_model_temp_files(code_path, weight_path) + return RResponse.fail(f"The model is invalid. {e}", 400) - # Validate the model - try: - dummy_model_wrapper = DummyModelWrapper( - model, RServer.get_model_wrapper().device + # Construct the metadata for saving + metadata_4_save = construct_metadata_4_save( + metadata, code_path, weight_path, model ) - val_model(RServer.get_data_manager(), dummy_model_wrapper) - except Exception as e: - traceback.print_exc() - clear_model_temp_files(base_dir, saving_id) - return RResponse.abort(400, f"The model is invalid. {e}") - # Update the metadata for saving - metadata_4_save["class_name"] = class_name - metadata_4_save["nickname"] = metadata.get("nickname") - metadata_4_save["description"] = ( - metadata.get("description") if metadata.get("description") else None - ) - metadata_4_save["tags"] = metadata.get("tags") if metadata.get("tags") else None - metadata_4_save["create_time"] = datetime.now() - metadata_4_save["code_path"] = code_path - metadata_4_save["weight_path"] = weight_path - metadata_4_save["epoch"] = 0 - metadata_4_save["train_accuracy"] = None - metadata_4_save["val_accuracy"] = None - metadata_4_save["test_accuracy"] = None - metadata_4_save["last_eval_on_dev_set"] = None - metadata_4_save["last_eval_on_test_set"] = None - - # Save the model's architecture to the metadata - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - print(model) - metadata_4_save["architecture"] = buffer.getvalue() - - # Save the model's metadata to the database - try: + # Save the model's metadata to the database RServer.get_model_wrapper().create_model(metadata_4_save) - except Exception as e: - traceback.print_exc() - return RResponse.abort(500, f"Failed to save the model. {e}") - # Set the current model to the newly uploaded model - try: + # Set the current model to the newly uploaded model RServer.get_model_wrapper().set_current_model(metadata.get("nickname")) + + return RResponse.ok("Success") except Exception as e: traceback.print_exc() - return RResponse.abort(500, f"Failed to set the current model. {e}") - - return RResponse.ok("Success") + clear_model_temp_files(code_path, weight_path) + return RResponse.abort(500, f"Unexpected error. {e}") @model_api.route("/model/list", methods=["GET"]) diff --git a/back-end/apis/predict.py b/back-end/apis/predict.py index fd478dfb..9c537979 100644 --- a/back-end/apis/predict.py +++ b/back-end/apis/predict.py @@ -16,6 +16,7 @@ predict_api = Blueprint("predict_api", __name__) + # Return prediction result @predict_api.route("/predict/") def predict(split): diff --git a/back-end/database/model.py b/back-end/database/model.py index 6db820bd..f7816acc 100644 --- a/back-end/database/model.py +++ b/back-end/database/model.py @@ -31,6 +31,8 @@ class Models(db.Model): id = db.Column(db.Integer, primary_key=True, autoincrement=True) class_name = db.Column(db.String) nickname = db.Column(db.String) + predefined = db.Column(db.Boolean) + pretrained = db.Column(db.Boolean) description = db.Column(db.String) architecture = db.Column(db.String) tags = db.relationship("Tags", secondary=model_tag_rel, backref="models") diff --git a/back-end/modules/visualize_module/flashtorch_/utils/__init__.py b/back-end/modules/visualize_module/flashtorch_/utils/__init__.py index aecd7bf6..5fb7d415 100644 --- a/back-end/modules/visualize_module/flashtorch_/utils/__init__.py +++ b/back-end/modules/visualize_module/flashtorch_/utils/__init__.py @@ -123,7 +123,6 @@ def denormalize(tensor): def standardize_and_clip( tensor, min_value=0.0, max_value=1.0, saturation=0.1, brightness=0.5 ): - """Standardizes and clips input tensor. Standardizes the input tensor (mean = 0.0, std = 1.0). The color saturation diff --git a/back-end/objects/RModelWrapper.py b/back-end/objects/RModelWrapper.py index 4aab0e47..a3022310 100644 --- a/back-end/objects/RModelWrapper.py +++ b/back-end/objects/RModelWrapper.py @@ -21,6 +21,7 @@ AVAILABLE_MODELS = list(MODEL_INPUT_SHAPE.keys()) + # TODO(Chonghan): Change this class to RModelManager later. class RModelWrapper: def __init__( @@ -38,7 +39,7 @@ def __init__( self.model_name = "" # TODO: Should initialize to None. Remove in the future. - self.model = RModelWrapper.init_pre_defined_model( + self.model = RModelWrapper.init_predefined_model( network_type, pretrained, num_classes, self.device ) self.num_classes = num_classes @@ -152,23 +153,42 @@ def get_model_by_name(name) -> Models: def load_model_by_name(self, model_name: str): model_meta_data = RModelWrapper.get_model_by_name(model_name) - # TODO: need a way to distinguish between predefined and custom model - if model_meta_data.class_name in AVAILABLE_MODELS: - model = RModelWrapper.init_pre_defined_model( + if model_meta_data.predefined: + # TODO: use the number of classes from RDataManager + file_path = model_meta_data.code_path + module_name = "variables_module" + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + num_classes = module.num_classes + + model = RModelWrapper.init_predefined_model( model_meta_data.class_name, - False, - self.num_classes, + model_meta_data.pretrained, + num_classes, self.device, ) else: model = RModelWrapper.init_custom_model( - model_meta_data.code_path, model_name, self.device + model_meta_data.code_path, model_meta_data.class_name, self.device ) - model.load_state_dict(torch.load(model_meta_data.weight_path)) + if model_meta_data.weight_path: + model.load_state_dict(torch.load(model_meta_data.weight_path)) return model, model_meta_data @staticmethod - def init_pre_defined_model(network_type, pretrained, num_classes, device): + def init_predefined_model(network_type, pretrained, num_classes, device): + # Check if the model is supported + if network_type not in AVAILABLE_MODELS: + raise Exception( + f"Requested model type {network_type} not supported. Please check." + ) + # If the model is pretrained, it should have the same number of classes as the ImageNet model + if pretrained and num_classes != IMAGENET_OUTPUT_SIZE: + raise Exception( + f"Pretrained model is supposed to have {IMAGENET_OUTPUT_SIZE} classes as output." + ) + if network_type == "resnet-18": model = torchvision.models.resnet18( pretrained=pretrained, num_classes=num_classes @@ -207,23 +227,16 @@ def init_pre_defined_model(network_type, pretrained, num_classes, device): model = torchvision.models.alexnet( pretrained=pretrained, num_classes=num_classes ) - else: - raise NotImplementedError( - "Requested model type not supported. Please check." - ) return model.to(device) @staticmethod def init_custom_model(code_path, name, device): - """Initialize the custom model by importing the class with the specified name in the file specified by code_path""" - try: - spec = importlib.util.spec_from_file_location("model_def", code_path) - model_def = importlib.util.module_from_spec(spec) - spec.loader.exec_module(model_def) - model = getattr(model_def, name)() - return model.to(device) - except Exception as e: - print("Failed to initialize the model.") - print(e) - raise e + """ + Initialize the custom model by importing the class with the specified name in the file specified by code_path + """ + spec = importlib.util.spec_from_file_location("model_def", code_path) + model_def = importlib.util.module_from_spec(spec) + spec.loader.exec_module(model_def) + model = getattr(model_def, name)() + return model.to(device) diff --git a/back-end/objects/RServer.py b/back-end/objects/RServer.py index 522f4655..b7c0c654 100644 --- a/back-end/objects/RServer.py +++ b/back-end/objects/RServer.py @@ -5,12 +5,10 @@ # Wrapper for flask server instance class RServer: - server_instance = None # Use createServer method instead! def __init__(self, configs, base_dir, dataset_dir, ckpt_dir, app, socket): - app.config["SWAGGER"] = { "title": "Robustar API", "uiversion": 3, diff --git a/back-end/utils/model_utils.py b/back-end/utils/model_utils.py index 77ae7169..cc5656f4 100644 --- a/back-end/utils/model_utils.py +++ b/back-end/utils/model_utils.py @@ -1,22 +1,11 @@ import os import torch -import torchvision -from objects.RDataManager import RDataManager +import io +import contextlib +import json +from objects.RServer import RServer from utils.predict import get_image_prediction - - -IMAGENET_OUTPUT_SIZE = 1000 - -PREDEFINED_MODELS = [ - "ResNet18", - "ResNet34", - "ResNet50", - "ResNet101", - "ResNet152", - "mobilenet-v2", - "ResNet18-32x32", - "AlexNet", -] +from datetime import datetime # Used for model validation @@ -26,78 +15,81 @@ def __init__(self, model, device): self.device = device -def init_predefined_model(name, pretrained, num_classes): - """Initialize the predefined model with the specified name""" - # Check if the model name is valid - if name not in PREDEFINED_MODELS: - raise Exception(f"Predefined model name {name} not recognized.") +def precheck_request_4_upload_model(request): + errors = [] - # If the model is pretrained, it should have the same number of classes as the ImageNet model - if pretrained and num_classes != IMAGENET_OUTPUT_SIZE: - raise Exception( - f"Pretrained model is supposed to have {IMAGENET_OUTPUT_SIZE} classes as output." - ) + # Check for the presence of metadata + metadata_str = request.form.get("metadata") + if not metadata_str: + errors.append("The model metadata is missing.") + return errors - if name == "ResNet18": - model = torchvision.models.resnet18( - pretrained=pretrained, num_classes=num_classes - ) - elif name == "ResNet34": - model = torchvision.models.resnet34( - pretrained=pretrained, num_classes=num_classes - ) - elif name == "ResNet50": - model = torchvision.models.resnet50( - pretrained=pretrained, num_classes=num_classes - ) - elif name == "ResNet101": - model = torchvision.models.resnet101( - pretrained=pretrained, num_classes=num_classes - ) - elif name == "ResNet152": - model = torchvision.models.resnet152( - pretrained=pretrained, num_classes=num_classes - ) - elif name == "mobilenet-v2": - model = torchvision.models.mobilenet_v2( - pretrained=pretrained, num_classes=num_classes - ) - elif name == "ResNet18-32x32": - model = torchvision.models.ResNet( - torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes - ) - model.conv1 = torch.nn.Conv2d( - 3, 64, kernel_size=3, stride=1, padding=1, bias=False - ) - model.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1) - if pretrained: - raise Exception("Pretrained ResNet18-32x32 is not available.") - elif name == "AlexNet": - model = torchvision.models.alexnet( - pretrained=pretrained, num_classes=num_classes + metadata = json.loads(metadata_str) + + # Check for the presence of data + missing_keys = [] + required_keys = ["class_name", "nickname", "predefined", "pretrained"] + for key in required_keys: + if key not in metadata: + missing_keys.append(key) + if missing_keys: + errors.append( + f"The following metadata fields are missing: {', '.join(missing_keys)}" ) - return model + if metadata.get("predefined") == "0": + code = request.form.get("code") + if not code: + errors.append( + "Model definition code is missing but required when predefined is '0'." + ) + if metadata.get("pretrained") == "1": + weight_file = request.files.get("weight_file") + if weight_file: + errors.append("Weight file should not be specified when pretrained is '1'.") + + # Additional checks for metadata fields + if "class_name" in metadata and not isinstance(metadata["class_name"], str): + errors.append("class_name should be a string") + if "nickname" in metadata and not isinstance(metadata["nickname"], str): + errors.append("nickname should be a string") + if "predefined" in metadata: + if metadata["predefined"] not in ["0", "1"]: + errors.append("predefined should be a either '0' or '1'") + if "description" in metadata and not isinstance(metadata["description"], str): + errors.append("description should be a string") + if "pretrained" in metadata: + if metadata["pretrained"] not in ["0", "1"]: + errors.append("pretrained should be a either '0' or '1'") + if "num_classes" in metadata: + if ( + not isinstance(metadata["num_classes"], str) + or not metadata["num_classes"].isdigit() + ): + errors.append("num_classes should be a string representation of an integer") + if "tags" in metadata and not ( + isinstance(metadata["tags"], list) + and all(isinstance(tag, str) for tag in metadata["tags"]) + ): + errors.append("tags should be a list of strings") + + return errors -def clear_model_temp_files(base_dir, saving_id): +def clear_model_temp_files(code_path, weight_path): """Clear the temporary files associated with the model""" - code_path = os.path.join(base_dir, "generated", "models", f"{saving_id}.py") - weight_path = os.path.join(base_dir, "generated", "models", f"{saving_id}.pth") - try: + if code_path: if os.path.exists(code_path): os.remove(code_path) + if weight_path: if os.path.exists(weight_path): os.remove(weight_path) - except Exception as e: - print("Failed to clear the temporary files associated with the model.") - print(e) - raise e -def val_model(data_manager: RDataManager, model_wrapper: DummyModelWrapper): +def val_model(model_wrapper: DummyModelWrapper): """Validate the model by running the model against a small portion of the validation dataset""" # Get at most 10 samples from the validation dataset + data_manager = RServer.get_data_manager() dataset = data_manager.validationset samples = dataset.samples[:10] # Create a dummy model wrapper to pass to the predict function @@ -111,3 +103,61 @@ def val_model(data_manager: RDataManager, model_wrapper: DummyModelWrapper): data_manager.image_size, argmax=False, ) + + +def create_models_dir(): + # Check if the folder for saving models exists, if not, create it + models_dir = os.path.join(RServer.get_server().base_dir, "generated", "models") + if not os.path.exists(models_dir): + os.makedirs(models_dir) + if not os.path.exists(os.path.join(models_dir, "code")): + os.makedirs(os.path.join(models_dir, "code")) + if not os.path.exists(os.path.join(models_dir, "ckpt")): + os.makedirs(os.path.join(models_dir, "ckpt")) + + +def save_code(code, code_path): + with open(code_path, "w") as code_file: + code_file.write(code) + + +def save_ckpt_weight(weight_file, weight_path): + weight_file.save(weight_path) + + +def load_ckpt_weight(model, weight_path): + model.load_state_dict(torch.load(weight_path)) + + +def save_cur_weight(model, weight_path): + torch.save(model.state_dict(), weight_path) + + +def construct_metadata_4_save(metadata, code_path, weight_path, model): + # Construct the metadata for saving + metadata_4_save = { + "class_name": metadata.get("class_name"), + "nickname": metadata.get("nickname"), + "predefined": bool(int(metadata.get("predefined"))), + "pretrained": bool(int(metadata.get("pretrained"))), + "description": metadata.get("description"), + "tags": metadata.get("tags"), + "create_time": datetime.now(), + "code_path": code_path, + "weight_path": weight_path, + "epoch": 0, + "train_accuracy": None, + "val_accuracy": None, + "test_accuracy": None, + "last_trained": None, + "last_eval_on_dev_set": None, + "last_eval_on_test_set": None, + } + + # Save the model's architecture to the metadata + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + print(model) + metadata_4_save["architecture"] = buffer.getvalue() + + return metadata_4_save