Skip to content

Commit 56acac2

Browse files
committed
add integrators_utils test
1 parent d8eac3d commit 56acac2

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

MCintegration/integrators_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,41 @@
11
import unittest
2+
from unittest.mock import patch
3+
import os
24
import torch
35
import numpy as np
46
from integrators import Integrator, MonteCarlo, MarkovChainMonteCarlo
7+
from integrators import get_ip, get_open_port, setup
58

69
# from base import LinearMap, Uniform
710
from maps import Configuration
811

912

13+
class TestIntegrators(unittest.TestCase):
14+
@patch("socket.socket")
15+
def test_get_ip(self, mock_socket):
16+
# Mock the socket behavior
17+
mock_socket_instance = mock_socket.return_value
18+
mock_socket_instance.getsockname.return_value = ("192.168.1.1", 12345)
19+
ip = get_ip()
20+
self.assertEqual(ip, "192.168.1.1")
21+
22+
@patch("socket.socket")
23+
def test_get_open_port(self, mock_socket):
24+
# Mock the socket behavior
25+
mock_socket_instance = mock_socket.return_value.__enter__.return_value
26+
mock_socket_instance.getsockname.return_value = ("0.0.0.0", 54321)
27+
port = get_open_port()
28+
self.assertEqual(port, 54321)
29+
30+
@patch("torch.distributed.init_process_group")
31+
@patch("torch.cuda.set_device")
32+
@patch.dict(os.environ, {"LOCAL_RANK": "0"})
33+
def test_setup(self, mock_set_device, mock_init_process_group):
34+
setup(backend="gloo")
35+
mock_init_process_group.assert_called_once_with(backend="gloo")
36+
mock_set_device.assert_called_once_with(0)
37+
38+
1039
class TestIntegrator(unittest.TestCase):
1140
def setUp(self):
1241
self.bounds = torch.tensor([[0.0, 1.0], [-1.0, 1.0]], dtype=torch.float64)

0 commit comments

Comments
 (0)