Skip to content
This repository was archived by the owner on Feb 24, 2025. It is now read-only.

Commit ecfea65

Browse files
committed
Fix metrics to work with grayscale datasets (#9)
1 parent 1d25833 commit ecfea65

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

metrics/metric_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_l
213213
# Main loop.
214214
item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
215215
for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
216+
if images.shape[1] == 1:
217+
images = images.repeat([1, 3, 1, 1])
216218
features = detector(images.to(opts.device), **detector_kwargs)
217219
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
218220
progress.update(stats.num_items)
@@ -262,7 +264,10 @@ def run_generator(z, c):
262264
c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
263265
c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
264266
images.append(run_generator(z, c))
265-
features = detector(torch.cat(images), **detector_kwargs)
267+
images = torch.cat(images)
268+
if images.shape[1] == 1:
269+
images = images.repeat([1, 3, 1, 1])
270+
features = detector(images, **detector_kwargs)
266271
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
267272
progress.update(stats.num_items)
268273
return stats

0 commit comments

Comments
 (0)