You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
out=einsum(attn, v, 'b h i j, b h j d -> b h i d')
495
496
496
-
497
497
out=rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
498
498
returnself.to_out(out)
499
499
@@ -505,7 +505,7 @@ def __init__(
505
505
dim,
506
506
init_dim=None,
507
507
out_dim=None,
508
-
dim_mults= (1, 2, 4, 8),
508
+
dim_mults: Tuple[int, ...]= (1, 2, 4, 8),
509
509
channels=3,
510
510
learned_variance=False,
511
511
learned_sinusoidal_cond=False,
@@ -612,7 +612,7 @@ def __init__(
612
612
defdownsample_factor(self):
613
613
return2** (len(self.downs) -1)
614
614
615
-
defforward(self, x, times, x_self_cond=None):
615
+
defforward(self, x, times):
616
616
assertall([divisible_by(d, self.downsample_factor) fordinx.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'
0 commit comments