Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #117

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,47 @@
from fastapi import FastAPI, UploadFile, File
from PIL import Image
"""
This module creates a FastAPI application for making predictions using the PyTorch model defined in main.py.
"""

# PyTorch is an open source machine learning library based on the Torch library
import torch

# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints.
from fastapi import FastAPI, File, UploadFile

# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image

# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
from torchvision import transforms
from main import Net # Importing Net class from main.py

# Load the model
# Importing Net class from main.py
from main import Net

# 'model' is a Net instance loaded with the weights stored in "mnist_model.pth" and switched to evaluation mode
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 'transform' is a sequence of transformations applied to the images
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# 'app' represents the FastAPI application
app = FastAPI()


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
"""
This function predicts the digit in the uploaded image file.

Parameters:
- file (UploadFile): The image file to predict.

Returns:
- dict: A dictionary with the prediction.
"""
image = Image.open(file.file).convert("L")
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
Expand Down
49 changes: 39 additions & 10 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,64 @@
from PIL import Image
"""
This module is used to load and preprocess the MNIST dataset and define a PyTorch model.
"""

# numpy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays
import numpy as np

# PyTorch is an open source machine learning library based on the Torch library
import torch

# torch.nn is a sublibrary of PyTorch, provides classes to build neural networks
import torch.nn as nn

# torch.optim is a package implementing various optimization algorithms
import torch.optim as optim
from torchvision import datasets, transforms

# DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset
from torch.utils.data import DataLoader
import numpy as np

# torchvision is a library for PyTorch that provides access to popular datasets, model architectures, and image transformations for computer vision
from torchvision import datasets, transforms

# PIL is used for opening, manipulating, and saving many different image file formats
# (note: no PIL import follows this comment in this module)


# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 'transform' is a sequence of transformations applied to the images in the dataset
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
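# 'trainset' is the MNIST training split (train=True), stored in the current directory (download=True) with 'transform' applied to each image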
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
# 'trainloader' is a data loader for the MNIST dataset
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)


# Step 2: Define the PyTorch Model
class Net(nn.Module):
    """
    This class defines a simple feed-forward neural network for the MNIST dataset.

    The network is three fully connected layers (784 -> 128 -> 64 -> 10) with a
    ReLU after each of the first two layers and a log-softmax over the 10 outputs.

    Methods:
    - __init__: Initializes the neural network with three fully connected layers.
    - forward: Defines the forward pass of the neural network.
    """

    def __init__(self):
        """Create the three fully connected layers."""
        super().__init__()
        # The attribute names are part of the saved state dict keys
        # ("fc1.weight", ...), so they must stay as they are.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """
        Run the forward pass.

        Parameters:
        - x (torch.Tensor): Input whose elements can be regrouped into rows of
          28 * 28 values.

        Returns:
        - torch.Tensor: Log-probabilities with 10 columns, one row per
          flattened input.
        """
        # reshape (instead of view) also accepts non-contiguous tensors,
        # e.g. the result of a transpose; for contiguous input it is equivalent.
        x = x.reshape(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return nn.functional.log_softmax(x, dim=1)


# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
Expand All @@ -45,4 +74,4 @@ def forward(self, x):
loss.backward()
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")
torch.save(model.state_dict(), "mnist_model.pth")