Skip to content

Commit c21b49b

Browse files
committed
udpate integrators and utils tests
1 parent 3b8ddbf commit c21b49b

File tree

2 files changed

+179
-26
lines changed

2 files changed

+179
-26
lines changed

MCintegration/integrators_test.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import numpy as np
66
from integrators import Integrator, MonteCarlo, MarkovChainMonteCarlo
77
from integrators import get_ip, get_open_port, setup
8+
from integrators import gaussian, random_walk
89

910
from maps import Configuration
11+
from base import EPSILON
1012

1113

1214
class TestIntegrators(unittest.TestCase):
@@ -34,6 +36,50 @@ def test_setup(self, mock_set_device, mock_init_process_group):
3436
mock_init_process_group.assert_called_once_with(backend="gloo")
3537
mock_set_device.assert_called_once_with(0)
3638

39+
def test_random_walk(self):
40+
# Test random_walk function with default parameters
41+
dim = 3
42+
device = "cpu"
43+
dtype = torch.float32
44+
u = torch.rand(dim, device=device, dtype=dtype)
45+
step_size = 0.2
46+
47+
new_u = random_walk(dim, device, dtype, u, step_size=step_size)
48+
49+
# Validate output shape and range
50+
self.assertEqual(new_u.shape, u.shape)
51+
self.assertTrue(torch.all(new_u >= 0) and torch.all(new_u <= 1))
52+
53+
# Test with custom step size
54+
custom_step_size = 0.5
55+
new_u_custom = random_walk(dim, device, dtype, u, step_size=custom_step_size)
56+
self.assertEqual(new_u_custom.shape, u.shape)
57+
self.assertTrue(torch.all(new_u_custom >= 0) and torch.all(new_u_custom <= 1))
58+
59+
def test_gaussian(self):
60+
# Test gaussian function with default parameters
61+
dim = 3
62+
device = "cpu"
63+
dtype = torch.float32
64+
u = torch.rand(dim, device=device, dtype=dtype)
65+
66+
mean = torch.zeros_like(u)
67+
std = torch.ones_like(u)
68+
69+
new_u = gaussian(dim, device, dtype, u, mean=mean, std=std)
70+
71+
# Validate output shape
72+
self.assertEqual(new_u.shape, u.shape)
73+
74+
# Test with custom mean and std
75+
custom_mean = torch.full_like(u, 0.5)
76+
custom_std = torch.full_like(u, 0.1)
77+
new_u_custom = gaussian(dim, device, dtype, u, mean=custom_mean, std=custom_std)
78+
79+
self.assertEqual(new_u_custom.shape, u.shape)
80+
self.assertTrue(torch.all(new_u_custom > custom_mean - 3 * custom_std))
81+
self.assertTrue(torch.all(new_u_custom < custom_mean + 3 * custom_std))
82+
3783

3884
class TestIntegrator(unittest.TestCase):
3985
def setUp(self):
@@ -81,6 +127,18 @@ def test_initialization_with_maps(self, mock_linear_map, mock_composite_map):
81127
self.assertTrue(hasattr(integrator.maps, "device"))
82128
self.assertTrue(hasattr(integrator.maps, "dtype"))
83129

130+
integrator = Integrator(
131+
bounds=self.bounds,
132+
f=self.f,
133+
maps=mock_map,
134+
batch_size=self.batch_size,
135+
device="cpu",
136+
)
137+
138+
# Assertions
139+
self.assertEqual(integrator.device, "cpu")
140+
self.assertTrue(hasattr(integrator.maps, "device"))
141+
84142
def test_bounds_conversion(self):
85143
# Test various input types
86144
test_cases = [
@@ -223,6 +281,27 @@ def test_batch_size_handling(self):
223281
# Should not raise warning
224282
self.mc(neval=neval, nblock=nblock)
225283

284+
def test_block_size_warning(self):
285+
mc = MonteCarlo(bounds=self.bounds, f=self.simple_integral, batch_size=1000)
286+
with self.assertWarns(UserWarning):
287+
mc(neval=500, nblock=10) # neval too small for nblock
288+
289+
def test_varying_nblock(self):
290+
test_cases = [
291+
(10000, 10), # Standard case
292+
(10000, 1), # Single block
293+
(10000, 100), # Many blocks
294+
]
295+
296+
for neval, nblock in test_cases:
297+
with self.subTest(neval=neval, nblock=nblock):
298+
result = self.mc(neval=neval, nblock=nblock)
299+
if hasattr(result, "mean"):
300+
value = result.mean
301+
else:
302+
value = result
303+
self.assertAlmostEqual(float(value), 1.0, delta=0.1)
304+
226305

227306
class TestMarkovChainMonteCarlo(unittest.TestCase):
228307
def setUp(self):
@@ -293,6 +372,41 @@ def test_burnin_effect(self):
293372
value = result
294373
self.assertAlmostEqual(float(value), 1.0, delta=tolerance)
295374

375+
def test_sample_acceptance(self):
376+
config = Configuration(
377+
self.mcmc.batch_size,
378+
self.mcmc.dim,
379+
self.mcmc.f_dim,
380+
self.mcmc.device,
381+
self.mcmc.dtype,
382+
)
383+
config.u, config.detJ = self.mcmc.q0.sample_with_detJ(self.mcmc.batch_size)
384+
config.x, detj = self.mcmc.maps.forward_with_detJ(config.u)
385+
config.detJ *= detj
386+
config.weight = torch.rand(self.mcmc.batch_size, device=self.mcmc.device)
387+
388+
self.mcmc.sample(config, nsteps=1, mix_rate=0.5)
389+
390+
# Validate acceptance logic
391+
self.assertTrue(torch.all(config.weight >= EPSILON))
392+
self.assertEqual(config.u.shape, config.x.shape)
393+
394+
def test_varying_mix_rate(self):
395+
test_cases = [
396+
(0.1, 0.2), # Low mix rate
397+
(0.5, 0.1), # Medium mix rate
398+
(0.9, 0.05), # High mix rate
399+
]
400+
401+
for mix_rate, tolerance in test_cases:
402+
with self.subTest(mix_rate=mix_rate):
403+
result = self.mcmc(neval=50000, mix_rate=mix_rate, nblock=10)
404+
if hasattr(result, "mean"):
405+
value = result.mean
406+
else:
407+
value = result
408+
self.assertAlmostEqual(float(value), 1.0, delta=tolerance)
409+
296410

297411
class TestDistributedFunctionality(unittest.TestCase):
298412
@unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")

MCintegration/utils_test.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -142,21 +142,39 @@ def test_converged(self):
142142
self.weighted_ravg.add(gvar.gvar(1.0, 0.01))
143143
self.assertTrue(self.weighted_ravg.converged(0.1, 0.1))
144144

145-
# def test_multiplication(self):
146-
# ravg1 = RAvg(weighted=True)
147-
# ravg1.update(2.0, 0.1)
148-
# ravg2 = RAvg(weighted=True)
149-
# ravg2.update(3.0, 0.1)
150-
# result = ravg1 * ravg2
151-
# self.assertAlmostEqual(result.mean, 6.0)
152-
153-
# def test_division(self):
154-
# ravg1 = RAvg(weighted=True)
155-
# ravg1.update(6.0, 0.1)
156-
# ravg2 = RAvg(weighted=True)
157-
# ravg2.update(3.0, 0.1)
158-
# result = ravg1 / ravg2
159-
# self.assertAlmostEqual(result.mean, 2.0)
145+
def test_reduce_ex_serialization(self):
146+
ravg = RAvg(weighted=True)
147+
ravg.add(gvar.gvar(1.0, 0.1))
148+
reduced = ravg.__reduce_ex__(2)
149+
restored = reduced[0](*reduced[1])
150+
self.assertEqual(restored.mean, ravg.mean)
151+
self.assertEqual(restored.sdev, ravg.sdev)
152+
153+
def test_summary_output(self):
154+
ravg = RAvg(weighted=True)
155+
ravg.add(gvar.gvar(1.0, 0.1))
156+
summary = ravg.summary()
157+
self.assertIn("itn", summary)
158+
self.assertIn("integral", summary)
159+
self.assertIn("wgt average", summary)
160+
161+
def test_converged_criteria(self):
162+
ravg = RAvg(weighted=True)
163+
ravg.add(gvar.gvar(1.0, 0.1))
164+
self.assertTrue(ravg.converged(0.1, 0.1))
165+
self.assertFalse(ravg.converged(0.001, 0.001))
166+
167+
def test_multiplication_with_another_ravg(self):
168+
ravg1 = RAvg(weighted=True)
169+
ravg1.update(2.0, 0.1)
170+
ravg2 = RAvg(weighted=True)
171+
ravg2.update(3.0, 0.1)
172+
173+
result = ravg1 * ravg2
174+
self.assertAlmostEqual(result.mean, 6.0)
175+
sdev = (0.1 / 2**2 + 0.1 / 3**2) ** 0.5 * 6.0
176+
self.assertAlmostEqual(result.sdev, sdev)
177+
160178
def test_multiplication(self):
161179
ravg1 = RAvg(weighted=True)
162180
# Test multiplication by another RAvg object
@@ -198,6 +216,17 @@ def test_multiplication(self):
198216
np.allclose([r.sdev for r in result], [2.0 * ravg1.sdev, 3.0 * ravg1.sdev])
199217
)
200218

219+
def test_division_with_another_ravg(self):
220+
ravg1 = RAvg(weighted=True)
221+
ravg1.update(6.0, 0.1)
222+
ravg2 = RAvg(weighted=True)
223+
ravg2.update(3.0, 0.1)
224+
225+
result = ravg1 / ravg2
226+
self.assertAlmostEqual(result.mean, 2.0)
227+
sdev = (0.1 / 6.0**2 + 0.1 / 3.0**2) ** 0.5 * 2.0
228+
self.assertAlmostEqual(result.sdev, sdev)
229+
201230
def test_division(self):
202231
ravg1 = RAvg(weighted=True)
203232
ravg1.update(6.0, 0.1)
@@ -254,25 +283,35 @@ def test_set_seed_cpu(self):
254283
# Test set_seed on a CPU-only environment
255284
set_seed(42)
256285
self.assertEqual(torch.initial_seed(), 42)
257-
258-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
259-
def test_set_seed_cuda(self):
260-
# Test set_seed on a CUDA-enabled environment
286+
u1 = torch.rand(10)
261287
set_seed(42)
262-
self.assertEqual(torch.initial_seed(), 42)
263-
self.assertEqual(torch.cuda.initial_seed(), 42)
288+
u2 = torch.rand(10)
289+
self.assertTrue(torch.all(u1 == u2))
290+
291+
# @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
292+
# def test_set_seed_cuda(self):
293+
# # Test set_seed on a CUDA-enabled environment
294+
# set_seed(42)
295+
# self.assertEqual(torch.initial_seed(), 42)
296+
# self.assertEqual(torch.cuda.initial_seed(), 42)
264297

265298
@unittest.skipIf(torch.cuda.is_available(), "CUDA is available")
266299
def test_get_device_cpu(self):
267300
# Test get_device when CUDA is not available
268301
device = get_device()
269302
self.assertEqual(device, torch.device("cpu"))
270303

271-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
272-
def test_get_device_cuda(self):
273-
# Test get_device when CUDA is available
274-
device = get_device()
275-
self.assertEqual(device, torch.cuda.current_device())
304+
# @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
305+
# def test_get_device_cuda(self):
306+
# # Test get_device when CUDA is available
307+
# device = get_device()
308+
# self.assertEqual(device, torch.cuda.current_device())
309+
310+
def test_get_device_cuda_inactive(self):
311+
if not torch.cuda.is_available():
312+
torch.cuda.set_device(-1) # Simulate inactive CUDA
313+
device = get_device()
314+
self.assertEqual(device, torch.device("cpu"))
276315

277316

278317
if __name__ == "__main__":

0 commit comments

Comments
 (0)