@@ -125,13 +125,11 @@ def __init__(
125
125
self .device = device
126
126
127
127
if isinstance (bounds , (list , np .ndarray )):
128
- self .bounds = torch .tensor (
129
- bounds , dtype = self .dtype , device = self .device )
128
+ self .bounds = torch .tensor (bounds , dtype = self .dtype , device = self .device )
130
129
elif isinstance (bounds , torch .Tensor ):
131
130
self .bounds = bounds .to (dtype = self .dtype , device = self .device )
132
131
else :
133
- raise TypeError (
134
- "bounds must be a list, NumPy array, or torch.Tensor." )
132
+ raise TypeError ("bounds must be a list, NumPy array, or torch.Tensor." )
135
133
136
134
assert self .bounds .shape [1 ] == 2 , "bounds must be a 2D array"
137
135
@@ -177,13 +175,18 @@ def __call__(self, **kwargs):
177
175
def sample (self , config , ** kwargs ):
178
176
"""
179
177
Generate samples for integration.
180
- This should be implemented by subclasses.
181
178
182
179
Args:
183
180
config (Configuration): The configuration object to store samples
184
181
**kwargs: Additional parameters
185
182
"""
186
- raise NotImplementedError ("Subclasses must implement this method." )
183
+ config .u , config .detJ = self .q0 .sample_with_detJ (config .batch_size )
184
+ if not self .maps :
185
+ config .x [:] = config .u
186
+ else :
187
+ config .x [:], detj = self .maps .forward_with_detJ (config .u )
188
+ config .detJ *= detj
189
+ self .f (config .x , config .fx )
187
190
188
191
def statistics (self , means , vars , neval = None ):
189
192
"""
@@ -212,11 +215,9 @@ def statistics(self, means, vars, neval=None):
212
215
gathered_vars = [
213
216
torch .zeros_like (vars ) for _ in range (self .world_size )
214
217
]
215
- dist .gather (means , gathered_means if self .rank ==
216
- 0 else None , dst = 0 )
218
+ dist .gather (means , gathered_means if self .rank == 0 else None , dst = 0 )
217
219
if weighted :
218
- dist .gather (vars , gathered_vars if self .rank ==
219
- 0 else None , dst = 0 )
220
+ dist .gather (vars , gathered_vars if self .rank == 0 else None , dst = 0 )
220
221
221
222
if self .rank == 0 :
222
223
results = np .array ([RAvg () for _ in range (f_dim )])
@@ -237,7 +238,7 @@ def statistics(self, means, vars, neval=None):
237
238
nblock_total , dtype = self .dtype , device = self .device
238
239
)
239
240
for igpu in range (self .world_size ):
240
- _means [igpu * nblock : (igpu + 1 ) * nblock ] = (
241
+ _means [igpu * nblock : (igpu + 1 ) * nblock ] = (
241
242
gathered_means [igpu ][:, i ]
242
243
)
243
244
results [i ].update (
@@ -341,8 +342,7 @@ def __call__(self, neval, nblock=32, verbose=-1, **kwargs):
341
342
integ_values = torch .zeros (
342
343
(self .batch_size , self .f_dim ), dtype = self .dtype , device = self .device
343
344
)
344
- means = torch .zeros ((nblock , self .f_dim ),
345
- dtype = self .dtype , device = self .device )
345
+ means = torch .zeros ((nblock , self .f_dim ), dtype = self .dtype , device = self .device )
346
346
vars = torch .zeros_like (means )
347
347
348
348
for iblock in range (nblock ):
@@ -354,8 +354,7 @@ def __call__(self, neval, nblock=32, verbose=-1, **kwargs):
354
354
vars [iblock , :] = integ_values .var (dim = 0 ) / self .batch_size
355
355
integ_values .zero_ ()
356
356
357
- results = self .statistics (
358
- means , vars , epoch_perblock * self .batch_size )
357
+ results = self .statistics (means , vars , epoch_perblock * self .batch_size )
359
358
360
359
if self .rank == 0 :
361
360
if self .f_dim == 1 :
@@ -382,8 +381,7 @@ def random_walk(dim, device, dtype, u, **kwargs):
382
381
"""
383
382
step_size = kwargs .get ("step_size" , 0.2 )
384
383
step_sizes = torch .ones (dim , device = device ) * step_size
385
- step = torch .empty (dim , device = device ,
386
- dtype = dtype ).uniform_ (- 1 , 1 ) * step_sizes
384
+ step = torch .empty (dim , device = device , dtype = dtype ).uniform_ (- 1 , 1 ) * step_sizes
387
385
new_u = (u + step ) % 1.0
388
386
return new_u
389
387
@@ -497,8 +495,7 @@ def sample(self, config, nsteps=1, mix_rate=0.5, **kwargs):
497
495
acceptance_probs = new_weight / config .weight * new_detJ / config .detJ
498
496
499
497
accept = (
500
- torch .rand (self .batch_size , dtype = self .dtype ,
501
- device = self .device )
498
+ torch .rand (self .batch_size , dtype = self .dtype , device = self .device )
502
499
<= acceptance_probs
503
500
)
504
501
@@ -560,8 +557,7 @@ def __call__(
560
557
config .x , detj = self .maps .forward_with_detJ (config .u )
561
558
config .detJ *= detj
562
559
config .weight = (
563
- mix_rate / config .detJ + (1 - mix_rate ) *
564
- self .f (config .x , config .fx ).abs_ ()
560
+ mix_rate / config .detJ + (1 - mix_rate ) * self .f (config .x , config .fx ).abs_ ()
565
561
)
566
562
config .weight .masked_fill_ (config .weight < EPSILON , EPSILON )
567
563
@@ -571,14 +567,11 @@ def __call__(
571
567
values = torch .zeros (
572
568
(self .batch_size , self .f_dim ), dtype = self .dtype , device = self .device
573
569
)
574
- refvalues = torch .zeros (
575
- self .batch_size , dtype = self .dtype , device = self .device )
570
+ refvalues = torch .zeros (self .batch_size , dtype = self .dtype , device = self .device )
576
571
577
- means = torch .zeros ((nblock , self .f_dim ),
578
- dtype = self .dtype , device = self .device )
572
+ means = torch .zeros ((nblock , self .f_dim ), dtype = self .dtype , device = self .device )
579
573
vars = torch .zeros_like (means )
580
- means_ref = torch .zeros (
581
- (nblock , 1 ), dtype = self .dtype , device = self .device )
574
+ means_ref = torch .zeros ((nblock , 1 ), dtype = self .dtype , device = self .device )
582
575
vars_ref = torch .zeros_like (means_ref )
583
576
584
577
for iblock in range (nblock ):
@@ -588,17 +581,15 @@ def __call__(
588
581
589
582
config .fx .div_ (config .weight .unsqueeze (1 ))
590
583
values += config .fx / n_meas_perblock
591
- refvalues += 1 / \
592
- (config .detJ * config .weight ) / n_meas_perblock
584
+ refvalues += 1 / (config .detJ * config .weight ) / n_meas_perblock
593
585
means [iblock , :] = values .mean (dim = 0 )
594
586
vars [iblock , :] = values .var (dim = 0 ) / self .batch_size
595
587
means_ref [iblock , 0 ] = refvalues .mean ()
596
588
vars_ref [iblock , 0 ] = refvalues .var () / self .batch_size
597
589
values .zero_ ()
598
590
refvalues .zero_ ()
599
591
600
- results_unnorm = self .statistics (
601
- means , vars , nsteps_perblock * self .batch_size )
592
+ results_unnorm = self .statistics (means , vars , nsteps_perblock * self .batch_size )
602
593
results_ref = self .statistics (
603
594
means_ref , vars_ref , nsteps_perblock * self .batch_size
604
595
)
0 commit comments