Skip to content

Commit 3be402d

Browse files
committed
update multicpu/gpu test
1 parent 38f2018 commit 3be402d

File tree

2 files changed

+106
-57
lines changed

2 files changed

+106
-57
lines changed

MCintegration/mc_multicpu_test.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

examples/mc_multicpu_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import torch
2+
import torch.distributed as dist
3+
import torch.multiprocessing as mp
4+
import os
5+
import traceback
6+
from integrators import MonteCarlo, MarkovChainMonteCarlo
7+
8+
# Set environment variables before spawning processes
9+
os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
10+
os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
11+
12+
backend = "gloo"
13+
14+
15+
def init_process(rank, world_size, fn, backend=backend):
    """Join the distributed process group, then run *fn* as this worker.

    Args:
        rank: Rank of this worker within the group.
        world_size: Total number of workers in the group.
        fn: Callable invoked as ``fn(rank, world_size)`` once the group is up.
        backend: torch.distributed backend name; defaults to the module-level
            ``backend`` ("gloo"), which works on CPU-only hosts.

    Raises:
        Re-raises any exception from group initialization or from *fn*,
        after printing a traceback and tearing down the process group.
    """
    try:
        # Rendezvous is driven by the MASTER_ADDR/MASTER_PORT env vars
        # set at module import time.
        dist.init_process_group(backend, rank=rank, world_size=world_size)
        fn(rank, world_size)
    except Exception as e:
        print(f"Error in process {rank}: {e}")
        traceback.print_exc()
        # Make sure to clean up before propagating the failure.
        if dist.is_initialized():
            dist.destroy_process_group()
        # Bare `raise` preserves the original exception and traceback
        # (idiomatic re-raise, unlike `raise e`).
        raise
def run_mcmc(rank, world_size):
    """Worker body: MCMC-integrate two functions over the square [-1, 1]^2.

    Expects the torch.distributed process group to be initialized already
    (see ``init_process``). Rank 0 prints the result; every rank tears the
    process group down in ``finally``.

    Args:
        rank: Rank of this worker within the group.
        world_size: Total number of workers in the group.

    Raises:
        Re-raises any exception from the integration after printing a
        traceback; the process group is destroyed either way.
    """
    print(world_size)
    try:
        # Different seed per rank so the chains are decorrelated while the
        # run as a whole stays reproducible.
        torch.manual_seed(42 + rank)

        # Instantiate the MarkovChainMonteCarlo class
        bounds = [(-1, 1), (-1, 1)]
        # n_eval = 8000000 // world_size  # Divide evaluations among processes
        n_eval = 8000000
        batch_size = 10000
        n_therm = 20

        # Two integrands over [-1, 1]^2, written into f in place:
        # column 0 is the indicator of the unit disc, column 1 is
        # 2 * max(0, 1 - r^2); the mean over columns is returned.
        def two_integrands(x, f):
            f[:, 0] = (x[:, 0] ** 2 + x[:, 1] ** 2 < 1).double()
            f[:, 1] = torch.clamp(1 - (x[:, 0] ** 2 + x[:, 1] ** 2), min=0) * 2
            return f.mean(dim=-1)

        # Choose device based on availability and rank.
        # NOTE(review): the strict '>' means GPUs are used only when there
        # are MORE devices than processes — confirm '>=' was not intended.
        if torch.cuda.is_available() and torch.cuda.device_count() > world_size:
            device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        else:
            device = torch.device("cpu")

        print(f"Process {rank} using device: {device}")

        mcmc = MarkovChainMonteCarlo(
            bounds=bounds,
            f=two_integrands,
            f_dim=2,
            batch_size=batch_size,
            nburnin=n_therm,
            device=device,
        )

        # Run the chain with n_eval evaluations.
        mcmc_result = mcmc(n_eval)

        if rank == 0:
            print("MarkovChainMonteCarlo Result:", mcmc_result)

    except Exception as e:
        print(f"Error in run_mcmc for rank {rank}: {e}")
        traceback.print_exc()
        # Bare `raise` preserves the original exception and traceback.
        raise
    finally:
        # Always tear down the process group, success or failure.
        if dist.is_initialized():
            dist.destroy_process_group()
def test_mcmc(world_size):
    """Spawn up to *world_size* workers (capped at the CPU count) and run
    the distributed MCMC example, reporting any spawn-level failure."""
    # Cap the worker count so we don't oversubscribe the machine.
    n_procs = min(world_size, mp.cpu_count())
    print(f"Starting with {n_procs} processes")

    try:
        mp.spawn(
            init_process,
            args=(n_procs, run_mcmc),
            nprocs=n_procs,
            join=True,
            daemon=False,
        )
    except Exception as err:
        # Best-effort: report and carry on. With join=True, spawn has
        # already terminated and reaped the child processes.
        print(f"Error in test_mcmc: {err}")
def _main():
    """Script entry point: configure multiprocessing and run the test."""
    # Force the "spawn" start method to prevent issues with
    # multiprocessing on some platforms.
    mp.set_start_method("spawn", force=True)
    test_mcmc(8)


if __name__ == "__main__":
    _main()

0 commit comments

Comments
 (0)