Skip to content

Commit 31eb32b

Browse files
Merge pull request SebastianLoef#4 from carlthome/add-pytest
Add unit test workflow
2 parents 3b54848 + d6b65ea commit 31eb32b

14 files changed

+137
-75
lines changed

.github/workflows/pytest.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
on: push
2+
3+
jobs:
4+
test:
5+
runs-on: ubuntu-latest
6+
steps:
7+
- uses: actions/checkout@v3
8+
9+
- name: Set up Python
10+
uses: actions/setup-python@v4
11+
with:
12+
python-version: "3.x"
13+
cache: pip
14+
15+
- name: Install dependencies
16+
run: make requirements
17+
18+
- name: Install test runner
19+
run: pip install pytest pytest-cov
20+
21+
- name: Run unit tests
22+
run: pytest --cov=src

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
[tool.isort]
22
profile = "black"
3+
4+
[tool.autoflake]
5+
remove_all_unused_imports = true

src/architectures.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import torch
21
import torch.nn as nn
32
from torchvision.models import ResNet50_Weights, resnet50
43

src/data/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from src.data.freemusicarchive import FreeMusicArchive
2+
from src.data.gtzan import GTZAN
3+
from src.data.magnatagatune import MagnaTagATune
4+
from src.data.millionsongdataset import MillionSongDataset
5+
from src.data.nsynth import NSynthInstrument, NSynthPitch
6+
7+
DATASETS = {
8+
"mtat": MagnaTagATune,
9+
"fma": FreeMusicArchive,
10+
"gtzan": GTZAN,
11+
"msd": MillionSongDataset,
12+
"nsynth_instrument": NSynthInstrument,
13+
"nsynth_pitch": NSynthPitch,
14+
}

src/data/test_dataset.py renamed to src/data/clips_dataset.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from typing import Tuple
22

33
import torch
4-
import torch.nn as nn
54
from torch.utils.data import Dataset
65

7-
from transforms import MelSpectrogram
6+
from src.transforms import MelSpectrogram
87

98

10-
class TestDataset(Dataset):
9+
class ClipsDataset(Dataset):
1110
def __init__(self, args, dataset: Dataset) -> None:
1211
super().__init__()
1312
self.dataset = dataset
@@ -33,7 +32,7 @@ def __len__(self) -> int:
3332
from torch.utils.data import DataLoader
3433

3534
dataset = MagnaTagATune("test")
36-
tdataset = TestDataset(dataset)
35+
tdataset = ClipsDataset(dataset)
3736
loader = DataLoader(tdataset, batch_size=1, shuffle=False)
3837

3938
for batch, label in loader:

src/data/encoded_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from torch import Tensor
77
from torchvision.transforms import Compose
88

9-
from transforms import MelSpectrogram, RandomResizedCrop
10-
from utils import generate_encodings, get_dataset
9+
from src.transforms import MelSpectrogram, RandomResizedCrop
10+
from src.utils import generate_encodings
1111

1212

1313
class EncodedDataset(nn.Module):
@@ -31,7 +31,7 @@ def get_integral_dataset(
3131
MelSpectrogram(backbone_args),
3232
]
3333
)
34-
dataset = get_dataset(args.dataset)(subset=subset, transforms=transforms)
34+
dataset = DATASETS[args.dataset](subset=subset, transforms=transforms)
3535
self.MULTILABEL = dataset.MULTILABEL
3636
self.NUM_LABELS = dataset.NUM_LABELS
3737
return dataset

src/data/nsynth.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import shutil
44
from typing import Tuple
55

6-
import pandas as pd
76
import requests
87
import torch
98
import torchaudio

src/evaluate.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
import argparse
22

3-
import lightning as L
4-
import numpy as np
53
import yaml
64
from sklearn import metrics
7-
from sklearn.linear_model import LinearRegression, LogisticRegression, SGDRegressor
5+
from sklearn.linear_model import SGDRegressor
86
from sklearn.multiclass import OneVsRestClassifier
9-
from sklearn.multioutput import ClassifierChain
10-
from sklearn.tree import DecisionTreeRegressor
117

12-
from architectures import resnet
13-
from modules.VICReg import VICReg
14-
from transforms import MelSpectrogram
15-
from utils import (
8+
from src.architectures import resnet
9+
from src.data import DATASETS
10+
from src.modules.VICReg import VICReg
11+
from src.transforms import MelSpectrogram
12+
from src.utils import (
1613
generate_encodings,
1714
get_best_metric_checkpoint_path,
18-
get_dataset,
1915
get_epoch_checkpoint_path,
2016
load_parameters,
2117
)
@@ -91,11 +87,9 @@ def main(args):
9187
# datasets
9288
############################
9389
transforms = MelSpectrogram(backbone_args)
94-
train_dataset = get_dataset(args.train_dataset)(
95-
subset="train", transforms=transforms
96-
)
97-
val_dataset = get_dataset(args.val_dataset)(subset="valid", transforms=transforms)
98-
test_dataset = get_dataset(args.test_dataset)(subset="test", transforms=transforms)
90+
train_dataset = DATASETS[args.dataset](subset="train", transforms=transforms)
91+
val_dataset = DATASETS[args.val_dataset](subset="valid", transforms=transforms)
92+
test_dataset = DATASETS[args.test_dataset](subset="test", transforms=transforms)
9993

10094
############################
10195
# model

src/modules/VICReg.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,22 @@
88
from torch import Tensor
99
from torch.utils.data import DataLoader
1010

11-
from architectures import mlp
12-
from optimizers import LARS, adjust_learning_rate, include_bias_and_norm
13-
from transforms import AudioSplit
14-
from utils import get_dataset, off_diagonal
11+
from src.architectures import mlp
12+
from src.optimizers import LARS, adjust_learning_rate, include_bias_and_norm
13+
from src.transforms import AudioSplit
14+
from src.utils import off_diagonal
1515

1616

1717
class VICReg(L.LightningModule):
18-
def __init__(self, args, backbone):
18+
def __init__(self, args, dataset, backbone):
1919
super().__init__()
2020
self.args = args
2121
self.num_features = int(args.projector.split("-")[-1])
2222
self.backbone = backbone
2323
self.projector = mlp(args.projector)
2424
self.val_outputs = []
2525
self.train_outputs = []
26+
self.dataset = dataset
2627

2728
def internal_forward(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
2829
x = self.projector(self.backbone(x))
@@ -103,9 +104,10 @@ def configure_optimizers(self):
103104
return optimizer
104105

105106
def train_dataloader(self) -> TRAIN_DATALOADERS:
106-
dataset = get_dataset(self.args.dataset)
107107
return DataLoader(
108-
dataset("train", transforms=AudioSplit(self.args), mixing=self.args.mixing),
108+
self.dataset(
109+
"train", transforms=AudioSplit(self.args), mixing=self.args.mixing
110+
),
109111
batch_size=self.args.batch_size,
110112
shuffle=True,
111113
num_workers=self.args.num_workers,

src/train_backbone.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from lightning.pytorch.callbacks import ModelCheckpoint
66
from lightning.pytorch.loggers import WandbLogger
77

8-
from architectures import resnet
9-
from modules.VICReg import VICReg
10-
from utils import get_model_name, get_model_number, save_parameters
8+
from src.architectures import resnet
9+
from src.data import DATASETS
10+
from src.modules.VICReg import VICReg
11+
from src.utils import get_model_name, get_model_number, save_parameters
1112

1213

1314
def get_arguments():
@@ -45,7 +46,8 @@ def main(args):
4546
# model
4647
############################
4748
backbone = resnet(args.pretrained)
48-
model = VICReg(args, backbone)
49+
dataset = DATASETS[args.dataset]
50+
model = VICReg(args, dataset, backbone)
4951
checkpoint = ModelCheckpoint(
5052
dirpath=f"data/models/{name}",
5153
filename="vicreg-{epoch:02d}",

src/train_head.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from torchaudio_augmentations import RandomResizedCrop
99
from torchvision.transforms import Compose
1010

11-
from architectures import resnet
12-
from data.test_dataset import TestDataset
13-
from modules.Classifier import Classifier
14-
from modules.VICReg import VICReg
15-
from transforms import MelSpectrogram
16-
from utils import (
11+
from src.architectures import resnet
12+
from src.data import DATASETS
13+
from src.data.clips_dataset import ClipsDataset
14+
from src.modules.Classifier import Classifier
15+
from src.modules.VICReg import VICReg
16+
from src.transforms import MelSpectrogram
17+
from src.utils import (
1718
class_balanced_sampler,
1819
get_best_metric_checkpoint_path,
19-
get_dataset,
2020
get_epoch_checkpoint_path,
2121
get_model_number,
2222
load_parameters,
@@ -65,11 +65,11 @@ def main(args):
6565
############################
6666
# dataset
6767
############################
68-
dataset = get_dataset(args.dataset)
68+
dataset = DATASETS[backbone_args.dataset]
6969
train_dataset = dataset(subset="train", transforms=transforms)
7070
val_dataset = dataset(subset="valid", transforms=transforms)
7171
test_dataset = dataset(subset="test", transforms=None)
72-
test_dataset = TestDataset(backbone_args, test_dataset)
72+
test_dataset = ClipsDataset(backbone_args, test_dataset)
7373
if args.class_balanced:
7474
sampler = class_balanced_sampler(train_dataset)
7575
shuffle = False
@@ -105,8 +105,9 @@ def main(args):
105105
# model
106106
############################
107107
print(backbone_path)
108+
dataset = DATASETS[backbone_args.dataset]
108109
backbone_module = VICReg.load_from_checkpoint(
109-
backbone_path, args=backbone_args, backbone=resnet()
110+
backbone_path, args=backbone_args, dataset=dataset, backbone=resnet()
110111
)
111112
backbone = backbone_module.backbone.cpu()
112113
model = Classifier(args, MULTILABELS, NUM_LABELS, backbone)

src/utils.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,6 @@
99
import torch
1010
from tqdm import tqdm
1111

12-
from data.freemusicarchive import FreeMusicArchive
13-
from data.gtzan import GTZAN
14-
from data.magnatagatune import MagnaTagATune
15-
from data.millionsongdataset import MillionSongDataset
16-
from data.nsynth import NSynthInstrument, NSynthPitch
17-
18-
# import wget
19-
2012

2113
def generate_encodings(args, module, dataset, subset, normalize=False):
2214
path = f"data/models/{args.name}/{args.dataset}"
@@ -82,28 +74,6 @@ def get_epoch_checkpoint_path(name: str, epoch: int = 0) -> str:
8274
return d[idx]
8375

8476

85-
def get_dataset(name: str):
86-
if name == "mtat":
87-
print("Using MagnaTagATune dataset")
88-
return MagnaTagATune
89-
elif name == "fma":
90-
print("Using FreeMusicArchive dataset")
91-
return FreeMusicArchive
92-
elif name == "gtzan":
93-
print("Using GTZAN dataset")
94-
return GTZAN
95-
elif name == "msd":
96-
print("Using MillionSongDataset dataset")
97-
return MillionSongDataset
98-
elif "nsynth" in name:
99-
if "instrument" in name:
100-
return NSynthInstrument
101-
elif "pitch" in name:
102-
return NSynthPitch
103-
104-
raise NotImplementedError
105-
106-
10777
def save_parameters(args, name):
10878
if not os.path.exists(f"data/models/{name}"):
10979
os.makedirs(f"data/models/{name}")

tests/__init__.py

Whitespace-only changes.

tests/test_train.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import argparse
2+
3+
import lightning
4+
import torch
5+
from torch.utils.data import Dataset
6+
7+
from src.architectures import resnet
8+
from src.modules.VICReg import VICReg
9+
10+
11+
class RandomDataset(Dataset):
12+
def __init__(self, subset, mixing, transforms) -> None:
13+
self.transforms = transforms
14+
15+
def __getitem__(self, index: int):
16+
x = torch.rand(3, 128, 128, requires_grad=True)
17+
y = torch.rand(3, 128, 128, requires_grad=True)
18+
z = torch.rand(1, requires_grad=True)
19+
return (x, y), z
20+
21+
def __len__(self) -> int:
22+
return 10
23+
24+
25+
def test_train():
26+
args = argparse.Namespace(
27+
batch_size=2,
28+
cov_coeff=1.0,
29+
devices=1,
30+
epochs=1,
31+
f_max=1.0,
32+
f_min=1.0,
33+
hop_length=1,
34+
mixing=False,
35+
n_fft=1,
36+
n_samples=2,
37+
normalize=False,
38+
num_workers=0,
39+
prefetch_factor=None,
40+
projector="2048-2-2048",
41+
sample_rate=3,
42+
sim_coeff=1.0,
43+
std_coeff=1.0,
44+
strategy="auto",
45+
weight_decay=1e-9,
46+
win_length=1,
47+
base_lr=1e-3,
48+
)
49+
50+
backbone = resnet(pretrained=False)
51+
dataset = RandomDataset
52+
model = VICReg(args=args, dataset=dataset, backbone=backbone)
53+
54+
trainer = lightning.Trainer(max_epochs=1)
55+
trainer.fit(model)
56+
57+
assert trainer.state.finished

0 commit comments

Comments
 (0)