
Commit bdc2655

some cleanup
1 parent 6c551ab commit bdc2655

2 files changed (+9, -8 lines):

pyproject.toml (+2, -1)
rectified_flow_pytorch/rectified_flow.py (+7, -7)


pyproject.toml (+2, -1)

@@ -1,6 +1,6 @@
 [project]
 name = "rectified-flow-pytorch"
-version = "0.0.10"
+version = "0.0.11"
 description = "Rectified Flow in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -23,6 +23,7 @@ classifiers=[

 dependencies = [
     'accelerate',
+    'einx>=0.3.0',
     'einops>=0.8.0',
     'ema-pytorch>=0.5.1',
     'pillow',
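
The only functional change to the package metadata is the new einx>=0.3.0 dependency, alongside the version bump to 0.0.11. A minimal sanity check after upgrading, assuming a standard pip environment:

# after: pip install -U rectified-flow-pytorch
import importlib.metadata

print(importlib.metadata.version('einx'))    # expect 0.3.0 or newer, per the new pin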

rectified_flow_pytorch/rectified_flow.py (+7, -7)

@@ -14,6 +14,7 @@
 from torchvision.utils import save_image
 from torchvision.models import VGG16_Weights

+import einx
 from einops import einsum, reduce, rearrange, repeat
 from einops.layers.torch import Rearrange

@@ -348,7 +349,7 @@ def forward(self, x):
         half_dim = self.dim // 2
         emb = math.log(self.theta) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
+        emb = einx.multiply('i, j -> i j', x, emb)
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
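
The outer product that builds the sinusoidal embedding is now written with einx instead of manual None-indexing. A small sketch of the equivalence (shapes below are illustrative, not taken from the file):

import torch
import einx

x   = torch.randn(8)       # e.g. a batch of timesteps, axis 'i'
emb = torch.randn(32)      # per-frequency scales, axis 'j'

old = x[:, None] * emb[None, :]               # broadcasted outer product, shape (8, 32)
new = einx.multiply('i, j -> i j', x, emb)    # same result, with named axes
assert torch.allclose(old, new)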

@@ -440,9 +441,9 @@ def forward(self, x):
         x = self.norm(x)

         qkv = self.to_qkv(x).chunk(3, dim = 1)
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
+        q, k, v = tuple(rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads) for t in qkv)

-        mk, mv = map(lambda t: repeat(t, 'h c n -> b h c n', b = b), self.mem_kv)
+        mk, mv = tuple(repeat(t, 'h c n -> b h c n', b = b) for t in self.mem_kv)
         k, v = map(partial(torch.cat, dim = -1), ((mk, k), (mv, v)))

         q = q.softmax(dim = -2)
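
The map(lambda ...) calls above are swapped for tuple(... for ...) generator expressions; both unpack to the same q, k, v (and mk, mv) tensors, the comprehension is just explicit about producing a tuple. A toy illustration of the pattern (values here are made up):

qkv = (1, 2, 3)

q1, k1, v1 = map(lambda t: t * 10, qkv)          # old style: lazy map, consumed by unpacking
q2, k2, v2 = tuple(t * 10 for t in qkv)          # new style: eager tuple of results
assert (q1, k1, v1) == (q2, k2, v2) == (10, 20, 30)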
@@ -474,7 +475,7 @@ def __init__(

         self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1, bias = False)

     def forward(self, x):
         b, c, h, w = x.shape
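
With this change the output projection, like to_qkv, carries no bias term. A quick check of what bias = False means for a 1x1 convolution (channel sizes are placeholders):

import torch.nn as nn

to_out = nn.Conv2d(64, 32, 1, bias = False)
assert to_out.bias is None                              # no bias parameter is registered
print(sum(p.numel() for p in to_out.parameters()))      # 32 * 64 * 1 * 1 = 2048 weights only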
@@ -493,7 +494,6 @@ def forward(self, x):
         attn = sim.softmax(dim = -1)
         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

-
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
         return self.to_out(out)
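
For reference, the shapes flowing through the attention output above, sketched with made-up sizes:

import torch
from einops import einsum, rearrange

b, heads, d, h, w = 2, 4, 16, 8, 8
n = h * w                                                # number of spatial positions

attn = torch.randn(b, heads, n, n).softmax(dim = -1)
v    = torch.randn(b, heads, n, d)

out = einsum(attn, v, 'b h i j, b h j d -> b h i d')     # (2, 4, 64, 16)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
assert out.shape == (b, heads * d, h, w)                 # (2, 64, 8, 8)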

@@ -505,7 +505,7 @@ def __init__(
         dim,
         init_dim = None,
         out_dim = None,
-        dim_mults = (1, 2, 4, 8),
+        dim_mults: Tuple[int, ...] = (1, 2, 4, 8),
         channels = 3,
         learned_variance = False,
         learned_sinusoidal_cond = False,
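
dim_mults now carries an explicit Tuple[int, ...] annotation. The diff does not show where Tuple is imported from; assuming the usual typing import, the annotation is used like this (build_stages is a hypothetical stand-in for the Unet constructor):

from typing import Tuple

def build_stages(dim: int = 64, dim_mults: Tuple[int, ...] = (1, 2, 4, 8)):
    # per-stage channel widths: one entry per resolution level
    return [dim * m for m in dim_mults]

print(build_stages())              # [64, 128, 256, 512]
print(build_stages(32, (1, 2)))    # [32, 64]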
@@ -612,7 +612,7 @@ def __init__(
     def downsample_factor(self):
         return 2 ** (len(self.downs) - 1)

-    def forward(self, x, times, x_self_cond = None):
+    def forward(self, x, times):
         assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'

         x = self.init_conv(x)
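
Since the self-conditioning argument is dropped from the signature, any caller that passed x_self_cond needs updating. A hedged usage sketch (the class name Unet and the call shapes are assumptions based on this file, not shown in the diff):

import torch
from rectified_flow_pytorch.rectified_flow import Unet    # class name assumed

unet = Unet(dim = 64)                          # dim_mults defaults to (1, 2, 4, 8) per the diff

images = torch.randn(1, 3, 64, 64)             # H and W must be divisible by unet.downsample_factor
times  = torch.rand(1)

# pred = unet(images, times, x_self_cond = None)    # old signature, now a TypeError
pred = unet(images, times)                           # new signature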
