Skip to content

Commit 444c112

Browse files
committed
fix device issue with noise
1 parent d60cc24 commit 444c112

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.7"
3+
version = "0.0.8"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def ode_fn(t, x):
182182

183183
# start with random gaussian noise - y0
184184

185-
noise = default(noise, torch.randn((batch_size, *data_shape)))
185+
noise = default(noise, torch.randn((batch_size, *data_shape), device = self.device))
186186

187187
# time steps
188188

@@ -280,6 +280,9 @@ def __init__(
280280

281281
self.frozen_model = frozen_model
282282

283+
def device(self):
284+
return next(self.parameters()).device
285+
283286
def parameters(self):
284287
return self.model.parameters() # omit frozen model
285288

@@ -288,7 +291,7 @@ def sample(self, *args, **kwargs):
288291

289292
def forward(self):
290293

291-
noise = torch.randn((self.batch_size, *self.data_shape))
294+
noise = torch.randn((self.batch_size, *self.data_shape), device = self.device)
292295
sampled_output = self.frozen_model.sample(noise = noise)
293296

294297
# the coupling in the paper is (noise, sampled_output)

0 commit comments

Comments
 (0)