Commit 3b8ddbf

update integrators_test
1 parent 56acac2 commit 3b8ddbf

2 files changed: +55 −52 lines


MCintegration/integrators.py

Lines changed: 5 additions & 5 deletions
@@ -233,7 +233,7 @@ def __call__(self, neval, nblock=32, verbose=-1, **kwargs):
 
         if verbose > 0:
             print(
-                f"nblock = {nblock}, n_steps_perblock = {epoch_perblock}, batch_size = {self.batch_size}, actual neval = {self.batch_size*epoch_perblock*nblock}"
+                f"nblock = {nblock}, n_steps_perblock = {epoch_perblock}, batch_size = {self.batch_size}, actual neval = {self.batch_size * epoch_perblock * nblock}"
             )
 
         config = Configuration(
@@ -356,13 +356,13 @@ def __call__(
         else:
             nblock = epoch // nsteps_perblock
             n_meas_perblock = nsteps_perblock // meas_freq
-            assert (
-                n_meas_perblock > 0
-            ), f"neval ({neval}) should be larger than batch_size * nblock * meas_freq ({self.batch_size} * {nblock} * {meas_freq})"
+            assert n_meas_perblock > 0, (
+                f"neval ({neval}) should be larger than batch_size * nblock * meas_freq ({self.batch_size} * {nblock} * {meas_freq})"
+            )
 
         if verbose > 0:
             print(
-                f"nblock = {nblock}, n_meas_perblock = {n_meas_perblock}, meas_freq = {meas_freq}, batch_size = {self.batch_size}, actual neval = {self.batch_size*nsteps_perblock*nblock}"
+                f"nblock = {nblock}, n_meas_perblock = {n_meas_perblock}, meas_freq = {meas_freq}, batch_size = {self.batch_size}, actual neval = {self.batch_size * nsteps_perblock * nblock}"
             )
 
         config = Configuration(
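Note: both hunks in integrators.py are formatting-only. Spaces are added around `*` inside the f-string expressions, and the long assert is rewrapped so the condition stays on the `assert` line with the message in parentheses, consistent with an auto-formatter's style. A minimal standalone sketch of the rewrapped assert; the values below are placeholders, not the integrator's actual attributes:

# Sketch only: batch_size, nblock, meas_freq, neval are made-up stand-ins
# for the integrator's attributes.
batch_size, nblock, meas_freq = 1000, 10, 2
neval = 50_000
nsteps_perblock = neval // (batch_size * nblock)
n_meas_perblock = nsteps_perblock // meas_freq

# Rewrapped form: condition on the assert line, message wrapped in parentheses.
assert n_meas_perblock > 0, (
    f"neval ({neval}) should be larger than batch_size * nblock * meas_freq "
    f"({batch_size} * {nblock} * {meas_freq})"
)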

MCintegration/integrators_test.py

Lines changed: 50 additions & 47 deletions
@@ -1,12 +1,11 @@
 import unittest
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 import os
 import torch
 import numpy as np
 from integrators import Integrator, MonteCarlo, MarkovChainMonteCarlo
 from integrators import get_ip, get_open_port, setup
 
-# from base import LinearMap, Uniform
 from maps import Configuration
 
 
@@ -54,6 +53,34 @@ def test_initialization(self):
         self.assertTrue(hasattr(integrator.maps, "device"))
         self.assertTrue(hasattr(integrator.maps, "dtype"))
 
+    @patch("MCintegration.maps.CompositeMap")
+    @patch("MCintegration.base.LinearMap")
+    def test_initialization_with_maps(self, mock_linear_map, mock_composite_map):
+        # Mock the LinearMap and CompositeMap
+        mock_linear_map_instance = MagicMock()
+        mock_linear_map.return_value = mock_linear_map_instance
+        mock_composite_map_instance = MagicMock()
+        mock_composite_map.return_value = mock_composite_map_instance
+
+        # Create a mock map
+        mock_map = MagicMock()
+        mock_map.device = "cpu"
+        mock_map.dtype = torch.float32
+        mock_map.forward_with_detJ.return_value = (torch.rand(10, 2), torch.rand(10))
+
+        # Initialize Integrator with maps
+        integrator = Integrator(
+            bounds=self.bounds, f=self.f, maps=mock_map, batch_size=self.batch_size
+        )
+
+        # Assertions
+        self.assertEqual(integrator.dim, 2)
+        self.assertEqual(integrator.batch_size, 1000)
+        self.assertEqual(integrator.f_dim, 1)
+        self.assertTrue(hasattr(integrator.maps, "forward_with_detJ"))
+        self.assertTrue(hasattr(integrator.maps, "device"))
+        self.assertTrue(hasattr(integrator.maps, "dtype"))
+
     def test_bounds_conversion(self):
         # Test various input types
         test_cases = [
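Note: the new test_initialization_with_maps never builds a real map; everything the Integrator touches (device, dtype, forward_with_detJ) is faked with unittest.mock.MagicMock. A self-contained sketch of that pattern, reusing the tensor shapes from the test but outside any test class:

# Self-contained sketch of the MagicMock pattern used above; the (10, 2) and
# (10,) shapes mirror the canned return value in the new test.
import torch
from unittest.mock import MagicMock

mock_map = MagicMock()
mock_map.device = "cpu"
mock_map.dtype = torch.float32
mock_map.forward_with_detJ.return_value = (torch.rand(10, 2), torch.rand(10))

# Any caller receives the canned tensors, and the call itself can be verified.
x, det_j = mock_map.forward_with_detJ(torch.rand(10, 2))
print(x.shape, det_j.shape)  # torch.Size([10, 2]) torch.Size([10])
mock_map.forward_with_detJ.assert_called_once()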
@@ -98,11 +125,11 @@ def test_invalid_bounds(self):
             with self.assertRaises(error_type):
                 Integrator(bounds=bounds, f=self.f)
 
-    def test_device_handling(self):
-        if torch.cuda.is_available():
-            integrator = Integrator(bounds=self.bounds, f=self.f, device="cuda")
-            self.assertTrue(integrator.bounds.is_cuda)
-            self.assertTrue(integrator.maps.device == "cuda")
+    # def test_device_handling(self):
+    #     if torch.cuda.is_available():
+    #         integrator = Integrator(bounds=self.bounds, f=self.f, device="cuda")
+    #         self.assertTrue(integrator.bounds.is_cuda)
+    #         self.assertTrue(integrator.maps.device == "cuda")
 
     def test_dtype_handling(self):
         dtypes = [torch.float32, torch.float64]
@@ -266,30 +293,6 @@ def test_burnin_effect(self):
             value = result
             self.assertAlmostEqual(float(value), 1.0, delta=tolerance)
 
-    # def test_mix_rate_sensitivity(self):
-    #     # Modified mix rate test to be more robust
-    #     mix_rates = [0.0, 0.5, 1.0]
-    #     results = []
-
-    #     for mix_rate in mix_rates:
-    #         accumulated_error = 0
-    #         n_trials = 3  # Run multiple trials for each mix_rate
-
-    #         for _ in range(n_trials):
-    #             result = self.mcmc(neval=50000, mix_rate=mix_rate, nblock=10)
-    #             if hasattr(result, "mean"):
-    #                 value = result.mean
-    #                 error = result.sdev
-    #             else:
-    #                 value = result
-    #                 error = abs(float(value) - 1.0)
-    #             accumulated_error += error
-
-    #         results.append(accumulated_error / n_trials)
-
-    #     # We expect moderate mix rates to have lower average error
-    #     self.assertLess(results[1], max(results[0], results[2]))
-
 
 class TestDistributedFunctionality(unittest.TestCase):
     @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
@@ -300,26 +303,26 @@ def test_distributed_initialization(self):
         self.assertEqual(integrator.rank, 0)
         self.assertEqual(integrator.world_size, 1)
 
-    @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
-    def test_multi_gpu_consistency(self):
-        if torch.cuda.device_count() >= 2:
-            bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
-            f = lambda x, fx: torch.ones_like(x)
+    # @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
+    # def test_multi_gpu_consistency(self):
+    #     if torch.cuda.device_count() >= 2:
+    #         bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
+    #         f = lambda x, fx: torch.ones_like(x)
 
-            # Create two integrators on different devices
-            integrator1 = Integrator(bounds=bounds, f=f, device="cuda:0")
-            integrator2 = Integrator(bounds=bounds, f=f, device="cuda:1")
+    #         # Create two integrators on different devices
+    #         integrator1 = Integrator(bounds=bounds, f=f, device="cuda:0")
+    #         integrator2 = Integrator(bounds=bounds, f=f, device="cuda:1")
 
-            # Results should be consistent across devices
-            result1 = integrator1(neval=10000)
-            result2 = integrator2(neval=10000)
+    #         # Results should be consistent across devices
+    #         result1 = integrator1(neval=10000)
+    #         result2 = integrator2(neval=10000)
 
-            if hasattr(result1, "mean"):
-                value1, value2 = result1.mean, result2.mean
-            else:
-                value1, value2 = result1, result2
+    #         if hasattr(result1, "mean"):
+    #             value1, value2 = result1.mean, result2.mean
+    #         else:
+    #             value1, value2 = result1, result2
 
-            self.assertAlmostEqual(float(value1), float(value2), places=1)
+    #         # self.assertAlmostEqual(float(value1), float(value2), places=1)
 
 
 if __name__ == "__main__":
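Note: the device-handling and multi-GPU tests are commented out rather than skipped. Since the file already uses unittest.skipIf for the distributed check, an alternative pattern (a sketch of a possible follow-up, not something this commit does) would keep the GPU tests visible and let unittest skip them on machines without CUDA:

# Sketch only: the skip-decorator pattern already used elsewhere in the file,
# applied to a CUDA check. The test body here is a placeholder.
import unittest
import torch

class TestDeviceHandlingSketch(unittest.TestCase):
    @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
    def test_device_handling(self):
        # Would construct an Integrator with device="cuda" and check tensor placement.
        self.assertTrue(torch.cuda.is_available())

if __name__ == "__main__":
    unittest.main()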

0 commit comments
