@@ -5,8 +5,10 @@
 import numpy as np
 from integrators import Integrator, MonteCarlo, MarkovChainMonteCarlo
 from integrators import get_ip, get_open_port, setup
+from integrators import gaussian, random_walk

 from maps import Configuration
+from base import EPSILON


 class TestIntegrators(unittest.TestCase):
@@ -34,6 +36,50 @@ def test_setup(self, mock_set_device, mock_init_process_group):
         mock_init_process_group.assert_called_once_with(backend="gloo")
         mock_set_device.assert_called_once_with(0)

+    def test_random_walk(self):
+        # Test random_walk function with default parameters
+        dim = 3
+        device = "cpu"
+        dtype = torch.float32
+        u = torch.rand(dim, device=device, dtype=dtype)
+        step_size = 0.2
+
+        new_u = random_walk(dim, device, dtype, u, step_size=step_size)
+
+        # Validate output shape and range
+        self.assertEqual(new_u.shape, u.shape)
+        self.assertTrue(torch.all(new_u >= 0) and torch.all(new_u <= 1))
+
+        # Test with custom step size
+        custom_step_size = 0.5
+        new_u_custom = random_walk(dim, device, dtype, u, step_size=custom_step_size)
+        self.assertEqual(new_u_custom.shape, u.shape)
+        self.assertTrue(torch.all(new_u_custom >= 0) and torch.all(new_u_custom <= 1))
+
+    def test_gaussian(self):
+        # Test gaussian function with default parameters
+        dim = 3
+        device = "cpu"
+        dtype = torch.float32
+        u = torch.rand(dim, device=device, dtype=dtype)
+
+        mean = torch.zeros_like(u)
+        std = torch.ones_like(u)
+
+        new_u = gaussian(dim, device, dtype, u, mean=mean, std=std)
+
+        # Validate output shape
+        self.assertEqual(new_u.shape, u.shape)
+
+        # Test with custom mean and std
+        custom_mean = torch.full_like(u, 0.5)
+        custom_std = torch.full_like(u, 0.1)
+        new_u_custom = gaussian(dim, device, dtype, u, mean=custom_mean, std=custom_std)
+
+        self.assertEqual(new_u_custom.shape, u.shape)
+        self.assertTrue(torch.all(new_u_custom > custom_mean - 3 * custom_std))
+        self.assertTrue(torch.all(new_u_custom < custom_mean + 3 * custom_std))
+

 class TestIntegrator(unittest.TestCase):
     def setUp(self):
@@ -81,6 +127,18 @@ def test_initialization_with_maps(self, mock_linear_map, mock_composite_map):
         self.assertTrue(hasattr(integrator.maps, "device"))
         self.assertTrue(hasattr(integrator.maps, "dtype"))

+        integrator = Integrator(
+            bounds=self.bounds,
+            f=self.f,
+            maps=mock_map,
+            batch_size=self.batch_size,
+            device="cpu",
+        )
+
+        # Assertions
+        self.assertEqual(integrator.device, "cpu")
+        self.assertTrue(hasattr(integrator.maps, "device"))
+
     def test_bounds_conversion(self):
         # Test various input types
         test_cases = [
@@ -223,6 +281,27 @@ def test_batch_size_handling(self):
         # Should not raise warning
         self.mc(neval=neval, nblock=nblock)

+    def test_block_size_warning(self):
+        mc = MonteCarlo(bounds=self.bounds, f=self.simple_integral, batch_size=1000)
+        with self.assertWarns(UserWarning):
+            mc(neval=500, nblock=10)  # neval too small for nblock
+
+    def test_varying_nblock(self):
+        test_cases = [
+            (10000, 10),  # Standard case
+            (10000, 1),  # Single block
+            (10000, 100),  # Many blocks
+        ]
+
+        for neval, nblock in test_cases:
+            with self.subTest(neval=neval, nblock=nblock):
+                result = self.mc(neval=neval, nblock=nblock)
+                if hasattr(result, "mean"):
+                    value = result.mean
+                else:
+                    value = result
+                self.assertAlmostEqual(float(value), 1.0, delta=0.1)
+

 class TestMarkovChainMonteCarlo(unittest.TestCase):
     def setUp(self):
@@ -293,6 +372,41 @@ def test_burnin_effect(self):
                     value = result
                 self.assertAlmostEqual(float(value), 1.0, delta=tolerance)

+    def test_sample_acceptance(self):
+        config = Configuration(
+            self.mcmc.batch_size,
+            self.mcmc.dim,
+            self.mcmc.f_dim,
+            self.mcmc.device,
+            self.mcmc.dtype,
+        )
+        config.u, config.detJ = self.mcmc.q0.sample_with_detJ(self.mcmc.batch_size)
+        config.x, detj = self.mcmc.maps.forward_with_detJ(config.u)
+        config.detJ *= detj
+        config.weight = torch.rand(self.mcmc.batch_size, device=self.mcmc.device)
+
+        self.mcmc.sample(config, nsteps=1, mix_rate=0.5)
+
+        # Validate acceptance logic
+        self.assertTrue(torch.all(config.weight >= EPSILON))
+        self.assertEqual(config.u.shape, config.x.shape)
+
+    def test_varying_mix_rate(self):
+        test_cases = [
+            (0.1, 0.2),  # Low mix rate
+            (0.5, 0.1),  # Medium mix rate
+            (0.9, 0.05),  # High mix rate
+        ]
+
+        for mix_rate, tolerance in test_cases:
+            with self.subTest(mix_rate=mix_rate):
+                result = self.mcmc(neval=50000, mix_rate=mix_rate, nblock=10)
+                if hasattr(result, "mean"):
+                    value = result.mean
+                else:
+                    value = result
+                self.assertAlmostEqual(float(value), 1.0, delta=tolerance)
+

 class TestDistributedFunctionality(unittest.TestCase):
     @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
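Note on the new proposal-function tests: test_random_walk and test_gaussian exercise random_walk and gaussian only through their call signatures and through the requirement that random_walk keeps proposals inside the unit hypercube. The sketch below shows proposal functions consistent with those tests; the bodies are assumptions inferred from the test code, not the actual implementations in integrators.py.

import torch

def random_walk(dim, device, dtype, u, step_size=0.2, **kwargs):
    # Uniform step of half-width `step_size` around u, wrapped back into
    # [0, 1) so that torch.all(new_u >= 0) and torch.all(new_u <= 1) hold.
    step = (torch.rand(dim, device=device, dtype=dtype) - 0.5) * 2.0 * step_size
    return (u + step) % 1.0

def gaussian(dim, device, dtype, u, mean=None, std=None, **kwargs):
    # Normal proposal centred on `mean` with scale `std`; defaults to a
    # standard normal draw when no parameters are supplied.
    if mean is None:
        mean = torch.zeros(dim, device=device, dtype=dtype)
    if std is None:
        std = torch.ones(dim, device=device, dtype=dtype)
    return mean + std * torch.randn(dim, device=device, dtype=dtype)

One caveat on the assertions themselves: for a genuine normal proposal, the 3-sigma bounds checked in test_gaussian hold only with high probability, so that test can occasionally be flaky unless the implementation truncates or clamps its samples.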