From 15beb04e9da043b37ef92f0d9048ee3744ee5ede Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Thu, 12 Oct 2023 22:07:03 +0000 Subject: [PATCH 1/2] feat: add tests for main.py --- tests/test_main.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/test_main.py diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..00b3c2e --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,30 @@ +import pytest +from pytest_mock import MockerFixture +from torchvision import datasets +from torch.utils.data import DataLoader +from src import main + +def test_data_loading_and_preprocessing(mocker: MockerFixture): + """Test the data loading and preprocessing steps.""" + mock_mnist = mocker.patch.object(datasets, 'MNIST') + mock_dataloader = mocker.patch.object(DataLoader, '__init__') + + main.load_and_preprocess_data() + + mock_mnist.assert_called_once_with('.', download=True, train=True, transform=main.transform) + mock_dataloader.assert_called_once_with(mock_mnist.return_value, batch_size=64, shuffle=True) + +def test_model_definition(): + """Test the model definition.""" + model = main.Net() + + assert isinstance(model, main.Net) + assert isinstance(model.fc1, nn.Linear) + assert model.fc1.in_features == 784 + assert model.fc1.out_features == 128 + assert isinstance(model.fc2, nn.Linear) + assert model.fc2.in_features == 128 + assert model.fc2.out_features == 64 + assert isinstance(model.fc3, nn.Linear) + assert model.fc3.in_features == 64 + assert model.fc3.out_features == 10 From e526e084ee118ba149bc505f8de9b5d250e9e46c Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Thu, 12 Oct 2023 22:09:17 +0000 Subject: [PATCH 2/2] feat: Updated tests/test_main.py --- tests/test_main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_main.py b/tests/test_main.py index 00b3c2e..533dfa5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,5 @@ import pytest +import torch.nn as nn from pytest_mock import MockerFixture from torchvision import datasets from torch.utils.data import DataLoader