Skip to content

Commit 481c53f

Browse files
committed
fix sample
1 parent c720bd2 commit 481c53f

File tree

1 file changed

+22
-31
lines changed

1 file changed

+22
-31
lines changed

MCintegration/integrators.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,11 @@ def __init__(
125125
self.device = device
126126

127127
if isinstance(bounds, (list, np.ndarray)):
128-
self.bounds = torch.tensor(
129-
bounds, dtype=self.dtype, device=self.device)
128+
self.bounds = torch.tensor(bounds, dtype=self.dtype, device=self.device)
130129
elif isinstance(bounds, torch.Tensor):
131130
self.bounds = bounds.to(dtype=self.dtype, device=self.device)
132131
else:
133-
raise TypeError(
134-
"bounds must be a list, NumPy array, or torch.Tensor.")
132+
raise TypeError("bounds must be a list, NumPy array, or torch.Tensor.")
135133

136134
assert self.bounds.shape[1] == 2, "bounds must be a 2D array"
137135

@@ -177,13 +175,18 @@ def __call__(self, **kwargs):
177175
def sample(self, config, **kwargs):
178176
"""
179177
Generate samples for integration.
180-
This should be implemented by subclasses.
181178
182179
Args:
183180
config (Configuration): The configuration object to store samples
184181
**kwargs: Additional parameters
185182
"""
186-
raise NotImplementedError("Subclasses must implement this method.")
183+
config.u, config.detJ = self.q0.sample_with_detJ(config.batch_size)
184+
if not self.maps:
185+
config.x[:] = config.u
186+
else:
187+
config.x[:], detj = self.maps.forward_with_detJ(config.u)
188+
config.detJ *= detj
189+
self.f(config.x, config.fx)
187190

188191
def statistics(self, means, vars, neval=None):
189192
"""
@@ -212,11 +215,9 @@ def statistics(self, means, vars, neval=None):
212215
gathered_vars = [
213216
torch.zeros_like(vars) for _ in range(self.world_size)
214217
]
215-
dist.gather(means, gathered_means if self.rank ==
216-
0 else None, dst=0)
218+
dist.gather(means, gathered_means if self.rank == 0 else None, dst=0)
217219
if weighted:
218-
dist.gather(vars, gathered_vars if self.rank ==
219-
0 else None, dst=0)
220+
dist.gather(vars, gathered_vars if self.rank == 0 else None, dst=0)
220221

221222
if self.rank == 0:
222223
results = np.array([RAvg() for _ in range(f_dim)])
@@ -237,7 +238,7 @@ def statistics(self, means, vars, neval=None):
237238
nblock_total, dtype=self.dtype, device=self.device
238239
)
239240
for igpu in range(self.world_size):
240-
_means[igpu * nblock: (igpu + 1) * nblock] = (
241+
_means[igpu * nblock : (igpu + 1) * nblock] = (
241242
gathered_means[igpu][:, i]
242243
)
243244
results[i].update(
@@ -341,8 +342,7 @@ def __call__(self, neval, nblock=32, verbose=-1, **kwargs):
341342
integ_values = torch.zeros(
342343
(self.batch_size, self.f_dim), dtype=self.dtype, device=self.device
343344
)
344-
means = torch.zeros((nblock, self.f_dim),
345-
dtype=self.dtype, device=self.device)
345+
means = torch.zeros((nblock, self.f_dim), dtype=self.dtype, device=self.device)
346346
vars = torch.zeros_like(means)
347347

348348
for iblock in range(nblock):
@@ -354,8 +354,7 @@ def __call__(self, neval, nblock=32, verbose=-1, **kwargs):
354354
vars[iblock, :] = integ_values.var(dim=0) / self.batch_size
355355
integ_values.zero_()
356356

357-
results = self.statistics(
358-
means, vars, epoch_perblock * self.batch_size)
357+
results = self.statistics(means, vars, epoch_perblock * self.batch_size)
359358

360359
if self.rank == 0:
361360
if self.f_dim == 1:
@@ -382,8 +381,7 @@ def random_walk(dim, device, dtype, u, **kwargs):
382381
"""
383382
step_size = kwargs.get("step_size", 0.2)
384383
step_sizes = torch.ones(dim, device=device) * step_size
385-
step = torch.empty(dim, device=device,
386-
dtype=dtype).uniform_(-1, 1) * step_sizes
384+
step = torch.empty(dim, device=device, dtype=dtype).uniform_(-1, 1) * step_sizes
387385
new_u = (u + step) % 1.0
388386
return new_u
389387

@@ -497,8 +495,7 @@ def sample(self, config, nsteps=1, mix_rate=0.5, **kwargs):
497495
acceptance_probs = new_weight / config.weight * new_detJ / config.detJ
498496

499497
accept = (
500-
torch.rand(self.batch_size, dtype=self.dtype,
501-
device=self.device)
498+
torch.rand(self.batch_size, dtype=self.dtype, device=self.device)
502499
<= acceptance_probs
503500
)
504501

@@ -560,8 +557,7 @@ def __call__(
560557
config.x, detj = self.maps.forward_with_detJ(config.u)
561558
config.detJ *= detj
562559
config.weight = (
563-
mix_rate / config.detJ + (1 - mix_rate) *
564-
self.f(config.x, config.fx).abs_()
560+
mix_rate / config.detJ + (1 - mix_rate) * self.f(config.x, config.fx).abs_()
565561
)
566562
config.weight.masked_fill_(config.weight < EPSILON, EPSILON)
567563

@@ -571,14 +567,11 @@ def __call__(
571567
values = torch.zeros(
572568
(self.batch_size, self.f_dim), dtype=self.dtype, device=self.device
573569
)
574-
refvalues = torch.zeros(
575-
self.batch_size, dtype=self.dtype, device=self.device)
570+
refvalues = torch.zeros(self.batch_size, dtype=self.dtype, device=self.device)
576571

577-
means = torch.zeros((nblock, self.f_dim),
578-
dtype=self.dtype, device=self.device)
572+
means = torch.zeros((nblock, self.f_dim), dtype=self.dtype, device=self.device)
579573
vars = torch.zeros_like(means)
580-
means_ref = torch.zeros(
581-
(nblock, 1), dtype=self.dtype, device=self.device)
574+
means_ref = torch.zeros((nblock, 1), dtype=self.dtype, device=self.device)
582575
vars_ref = torch.zeros_like(means_ref)
583576

584577
for iblock in range(nblock):
@@ -588,17 +581,15 @@ def __call__(
588581

589582
config.fx.div_(config.weight.unsqueeze(1))
590583
values += config.fx / n_meas_perblock
591-
refvalues += 1 / \
592-
(config.detJ * config.weight) / n_meas_perblock
584+
refvalues += 1 / (config.detJ * config.weight) / n_meas_perblock
593585
means[iblock, :] = values.mean(dim=0)
594586
vars[iblock, :] = values.var(dim=0) / self.batch_size
595587
means_ref[iblock, 0] = refvalues.mean()
596588
vars_ref[iblock, 0] = refvalues.var() / self.batch_size
597589
values.zero_()
598590
refvalues.zero_()
599591

600-
results_unnorm = self.statistics(
601-
means, vars, nsteps_perblock * self.batch_size)
592+
results_unnorm = self.statistics(means, vars, nsteps_perblock * self.batch_size)
602593
results_ref = self.statistics(
603594
means_ref, vars_ref, nsteps_perblock * self.batch_size
604595
)

0 commit comments

Comments
 (0)