From 580274be79c4cb11e1fe1ef9dcaaeba9c7be669a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 7 Mar 2023 12:41:55 -0800 Subject: [PATCH] use .to(device) to avoid copy, within one_unet_in_gpu context --- dalle2_pytorch/dalle2_pytorch.py | 15 +++++++++++---- dalle2_pytorch/version.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 1f94304b..c0211fc1 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -2727,11 +2727,16 @@ def one_unet_in_gpu(self, unet_number = None, unet = None): if exists(unet_number): unet = self.get_unet(unet_number) + # devices + + cuda, cpu = torch.device('cuda'), torch.device('cpu') + self.cuda() devices = [module_device(unet) for unet in self.unets] - self.unets.cpu() - unet.cuda() + + self.unets.to(cpu) + unet.to(cuda) yield @@ -3114,7 +3119,8 @@ def sample( distributed = False, inpaint_image = None, inpaint_mask = None, - inpaint_resample_times = 5 + inpaint_resample_times = 5, + one_unet_in_gpu_at_time = True ): assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally' @@ -3137,6 +3143,7 @@ def sample( assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size) prev_unet_output_size = self.image_sizes[start_at_unet_number - 2] img = resize_image_to(image, prev_unet_output_size, nearest = True) + is_cuda = next(self.parameters()).is_cuda num_unets = self.num_unets @@ -3146,7 +3153,7 @@ def sample( if unet_number < start_at_unet_number: continue # It's the easiest way to do it - context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context() + context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context() with context: # prepare low resolution conditioning for upsamplers diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 3081afba..aba17a1d 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.12.3' +__version__ = '1.12.4'