@@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
             Output channels
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm.
         use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
             Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
     in_channels: int
     out_channels: int = None
     dropout: float = 0.0
+    groups: int = 32
     use_nin_shortcut: bool = None
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
         out_channels = self.in_channels if self.out_channels is None else self.out_channels
 
-        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.conv1 = nn.Conv(
             out_channels,
             kernel_size=(3, 3),
@@ -143,7 +146,7 @@ def setup(self):
             dtype=self.dtype,
         )
 
-        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.dropout_layer = nn.Dropout(self.dropout)
         self.conv2 = nn.Conv(
             out_channels,
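The new `groups` field above replaces a hard-coded value of 32 in both norm layers. As a minimal sketch (not part of this PR; assumes `jax` and `flax` are installed, shapes are illustrative), Flax's `nn.GroupNorm` requires the group count to divide the feature dimension, which is why making it configurable matters for blocks whose channel count is not a multiple of 32:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# nn.GroupNorm normalizes over the last (feature) axis and needs
# num_groups to divide that dimension exactly.
norm = nn.GroupNorm(num_groups=16, epsilon=1e-6)

x = jnp.ones((1, 8, 8, 64))               # NHWC; 64 features -> 16 groups of 4
params = norm.init(jax.random.PRNGKey(0), x)
y = norm.apply(params, x)                 # output shape matches the input
assert y.shape == x.shape
```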
@@ -191,20 +194,23 @@ class FlaxAttentionBlock(nn.Module):
             Input channels
         num_head_channels (:obj:`int`, *optional*, defaults to `None`):
             Number of attention heads
+        num_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
 
     """
     channels: int
     num_head_channels: int = None
+    num_groups: int = 32
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
         self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
 
         dense = partial(nn.Dense, self.channels, dtype=self.dtype)
 
-        self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
         self.query, self.key, self.value = dense(), dense(), dense()
         self.proj_attn = dense()
 
@@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_downsample: bool = True
     dtype: jnp.dtype = jnp.float32
 
@@ -285,6 +294,7 @@ def setup(self):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -303,9 +313,9 @@ def __call__(self, hidden_states, deterministic=True):
         return hidden_states
 
 
-class FlaxUpEncoderBlock2D(nn.Module):
+class FlaxUpDecoderBlock2D(nn.Module):
     r"""
-    Flax Resnet blocks-based Encoder block for diffusion-based VAE.
+    Flax Resnet blocks-based Decoder block for diffusion-based VAE.
 
     Parameters:
         in_channels (:obj:`int`):
@@ -316,15 +326,18 @@ class FlaxUpEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
-        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
-            Whether to add downsample layer
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add upsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
     in_channels: int
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_upsample: bool = True
     dtype: jnp.dtype = jnp.float32
 
@@ -336,6 +349,7 @@ def setup(self):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
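As a hedged usage sketch of the renamed block (the constructor arguments come from the dataclass fields in the hunks above; the import path, call signature, and shapes are assumptions, not taken from this diff):

```python
import jax
import jax.numpy as jnp

# Assumed import path for the patched module.
from diffusers.models.vae_flax import FlaxUpDecoderBlock2D

up_block = FlaxUpDecoderBlock2D(
    in_channels=256,
    out_channels=128,
    num_layers=3,
    resnet_groups=16,   # forwarded to each FlaxResnetBlock2D group norm
    add_upsample=True,
    dtype=jnp.float32,
)

sample = jnp.zeros((1, 16, 16, 256))  # NHWC; sizes are illustrative
params = up_block.init(jax.random.PRNGKey(0), sample, deterministic=True)
out = up_block.apply(params, sample, deterministic=True)  # spatial dims doubled by the upsample layer
```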
@@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet and Attention block group norm
         attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
             Number of attention heads for each attention block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +390,20 @@ class FlaxUNetMidBlock2D(nn.Module):
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     attn_num_head_channels: int = 1
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
+        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
+
         # there is always at least one resnet
         resnets = [
             FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
         ]
@@ -392,14 +412,18 @@ def setup(self):
 
         for _ in range(self.num_layers):
             attn_block = FlaxAttentionBlock(
-                channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
+                channels=self.in_channels,
+                num_head_channels=self.attn_num_head_channels,
+                num_groups=resnet_groups,
+                dtype=self.dtype,
             )
             attentions.append(attn_block)
 
             res_block = FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
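The fallback added at the top of `FlaxUNetMidBlock2D.setup()` can be read on its own; a short sketch with illustrative channel counts (the helper name is hypothetical):

```python
# Mirrors the fallback in FlaxUNetMidBlock2D.setup(): an explicit
# resnet_groups wins, otherwise use min(in_channels // 4, 32).
def resolve_groups(in_channels, resnet_groups=None):
    return resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

assert resolve_groups(512) == 32     # large blocks cap at 32 groups
assert resolve_groups(64) == 16      # small blocks shrink the group count
assert resolve_groups(512, 8) == 8   # an explicit value is always respected
```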
@@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
             Tuple containing the number of output channels for each block
         layers_per_block (:obj:`int`, *optional*, defaults to `2`):
             Number of Resnet layer for each block
-        norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
+        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             norm num group
         act_fn (:obj:`str`, *optional*, defaults to `silu`):
             Activation function
@@ -483,6 +507,7 @@ def setup(self):
                 in_channels=input_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block,
+                resnet_groups=self.norm_num_groups,
                 add_downsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -491,12 +516,15 @@ def setup(self):
 
         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )
 
         # end
         conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             conv_out_channels,
             kernel_size=(3, 3),
@@ -581,7 +609,10 @@ def setup(self):
 
         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )
 
         # upsampling
@@ -594,10 +625,11 @@ def setup(self):
 
             is_final_block = i == len(block_out_channels) - 1
 
-            up_block = FlaxUpEncoderBlock2D(
+            up_block = FlaxUpDecoderBlock2D(
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block + 1,
+                resnet_groups=self.norm_num_groups,
                 add_upsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -607,7 +639,7 @@ def setup(self):
         self.up_blocks = up_blocks
 
         # end
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),