diff --git a/src/api.py b/src/api.py
index 36c257a..e1d06ec 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,24 +1,47 @@
-from fastapi import FastAPI, UploadFile, File
-from PIL import Image
+"""
+This module creates a FastAPI application for making predictions using the PyTorch model defined in main.py.
+"""
+
+# PyTorch is an open source machine learning library based on the Torch library
 import torch
+
+# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints.
+from fastapi import FastAPI, File, UploadFile
+
+# PIL is used for opening, manipulating, and saving many different image file formats
+from PIL import Image
+
+# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
 from torchvision import transforms
-from main import Net # Importing Net class from main.py
 
-# Load the model
+# Importing Net class from main.py
+from main import Net
+
+# 'model' represents the PyTorch model
 model = Net()
 model.load_state_dict(torch.load("mnist_model.pth"))
 model.eval()
 
-# Transform used for preprocessing the image
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# 'transform' is a sequence of transformations applied to the images
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
+# 'app' represents the FastAPI application
 app = FastAPI()
 
+
 @app.post("/predict/")
 async def predict(file: UploadFile = File(...)):
+    """
+    This function predicts the digit in the uploaded image file.
+
+    Parameters:
+    - file (UploadFile): The image file to predict.
+
+    Returns:
+    - dict: A dictionary with the prediction.
+    """
     image = Image.open(file.file).convert("L")
     image = transform(image)
     image = image.unsqueeze(0) # Add batch dimension
diff --git a/src/main.py b/src/main.py
index 243a31e..d4beac6 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,28 +1,56 @@
-from PIL import Image
+"""
+This module is used to load and preprocess the MNIST dataset and define a PyTorch model.
+"""
+
+# numpy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays
+import numpy as np
+
+# PyTorch is an open source machine learning library based on the Torch library
 import torch
+
+# torch.nn is a sublibrary of PyTorch, provides classes to build neural networks
 import torch.nn as nn
+
+# torch.optim is a package implementing various optimization algorithms
 import torch.optim as optim
-from torchvision import datasets, transforms
+
+# DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset
 from torch.utils.data import DataLoader
-import numpy as np
+
+# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
+from torchvision import datasets, transforms
+
+# PIL is used for opening, manipulating, and saving many different image file formats
+
 
 # Step 1: Load MNIST Data and Preprocess
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+# 'transform' is a sequence of transformations applied to the images in the dataset
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
-trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
+# 'trainset' represents the MNIST dataset
+trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
+# 'trainloader' is a data loader for the MNIST dataset
 trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
 
+
 # Step 2: Define the PyTorch Model
 class Net(nn.Module):
+    """
+    This class defines a simple feed-forward neural network for the MNIST dataset.
+
+    Methods:
+    - __init__: Initializes the neural network with three fully connected layers.
+    - forward: Defines the forward pass of the neural network.
+    """
+
     def __init__(self):
         super().__init__()
         self.fc1 = nn.Linear(28 * 28, 128)
         self.fc2 = nn.Linear(128, 64)
         self.fc3 = nn.Linear(64, 10)
-        
+
     def forward(self, x):
         x = x.view(-1, 28 * 28)
         x = nn.functional.relu(self.fc1(x))
@@ -30,6 +58,7 @@ def forward(self, x):
         x = self.fc3(x)
         return nn.functional.log_softmax(x, dim=1)
 
+
 # Step 3: Train the Model
 model = Net()
 optimizer = optim.SGD(model.parameters(), lr=0.01)
@@ -45,4 +74,4 @@ def forward(self, x):
         loss.backward()
         optimizer.step()
 
-torch.save(model.state_dict(), "mnist_model.pth")
\ No newline at end of file
+torch.save(model.state_dict(), "mnist_model.pth")