Skip to content

Commit

Permalink
Merge pull request #196 from wbenoit26/seed-tests
Browse files Browse the repository at this point in the history
Set random seed for unit tests
  • Loading branch information
wbenoit26 authored Feb 6, 2025
2 parents 16b0248 + e7743f5 commit 097361d
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 212 deletions.
77 changes: 44 additions & 33 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import random

import numpy as np
import pytest
import torch
from scipy.special import erfinv
from torch.distributions import Uniform


# If a fixture is doing anything random,
# it should take this function as an argument
@pytest.fixture(autouse=True)
def seed_everything():
seed = 101589
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


@pytest.fixture
def compare_against_numpy():
"""
Expand Down Expand Up @@ -81,9 +93,8 @@ def validate(whitened, highpass, lowpass, sample_rate, df):
return validate


# number of samples to draw from
# the distributions for testing
N_SAMPLES = 1000
# A num_samples fixture should be defined for any
# test that wants to use these fixtures


@pytest.fixture(params=[256, 1024, 2048])
Expand All @@ -92,90 +103,90 @@ def sample_rate(request):


@pytest.fixture()
def chirp_mass(request):
def chirp_mass(num_samples, seed_everything):
dist = Uniform(5, 100)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def mass_ratio():
def mass_ratio(num_samples, seed_everything):
dist = Uniform(0.125, 0.99)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def a_1(request):
def a_1(num_samples, seed_everything):
dist = Uniform(0, 0.90)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def a_2(request):
def a_2(num_samples, seed_everything):
dist = Uniform(0, 0.90)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def tilt_1(request):
def tilt_1(num_samples, seed_everything):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def tilt_2(request):
def tilt_2(num_samples, seed_everything):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def phi_12(request):
def phi_12(num_samples, seed_everything):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def phi_jl(request):
def phi_jl(num_samples, seed_everything):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def distance(request):
def distance(num_samples, seed_everything):
dist = Uniform(100, 3000)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def distance_far(request):
def distance_far(num_samples, seed_everything):
dist = Uniform(500, 3000)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def distance_close(request):
def distance_close(num_samples, seed_everything):
dist = Uniform(100, 500)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def theta_jn(request):
def theta_jn(num_samples, seed_everything):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def phase(request):
def phase(num_samples, seed_everything):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def chi1(request):
def chi1(num_samples, seed_everything):
dist = Uniform(-0.999, 0.999)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def chi2(request):
def chi2(num_samples, seed_everything):
dist = Uniform(-0.999, 0.999)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))
69 changes: 2 additions & 67 deletions tests/transforms/test_iirfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from astropy import units as u
from scipy.signal import filtfilt, iirfilter
from torch.distributions import Uniform

from ml4gw.constants import MSUN
from ml4gw.transforms.iirfilter import IIRFilter
Expand Down Expand Up @@ -124,73 +123,9 @@ def test_filters_synthetic_signal(sample_rate, order):
)


N_SAMPLES = 1


@pytest.fixture()
def chirp_mass(request):
dist = Uniform(5, 100)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def mass_ratio():
dist = Uniform(0.125, 0.99)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def a_1(request):
dist = Uniform(0, 0.90)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def a_2(request):
dist = Uniform(0, 0.90)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def tilt_1(request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def tilt_2(request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phi_12(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phi_jl(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def distance(request):
dist = Uniform(100, 3000)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def theta_jn(request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phase(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))
def num_samples():
return 1


@pytest.fixture(params=[20, 40])
Expand Down
Loading

0 comments on commit 097361d

Please sign in to comment.