
Commit a124204

Flax: Trickle down norm_num_groups (#789)
* pass norm_num_groups param and add tests
* set resnet_groups for FlaxUNetMidBlock2D
* fixed docstrings
* fixed typo
* using is_flax_available util and created require_flax decorator
1 parent 66a5279 commit a124204

File tree

src/diffusers/models/vae_flax.py
src/diffusers/utils/testing_utils.py
tests/test_modeling_common_flax.py
tests/test_models_vae_flax.py

4 files changed: +138 -14 lines


src/diffusers/models/vae_flax.py

Lines changed: 46 additions & 14 deletions

@@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
             Output channels
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm.
         use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
             Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
     in_channels: int
     out_channels: int = None
     dropout: float = 0.0
+    groups: int = 32
     use_nin_shortcut: bool = None
     dtype: jnp.dtype = jnp.float32

     def setup(self):
         out_channels = self.in_channels if self.out_channels is None else self.out_channels

-        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.conv1 = nn.Conv(
             out_channels,
             kernel_size=(3, 3),
@@ -143,7 +146,7 @@ def setup(self):
             dtype=self.dtype,
         )

-        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
         self.dropout_layer = nn.Dropout(self.dropout)
         self.conv2 = nn.Conv(
             out_channels,
@@ -191,20 +194,23 @@ class FlaxAttentionBlock(nn.Module):
             Input channels
         num_head_channels (:obj:`int`, *optional*, defaults to `None`):
             Number of attention heads
+        num_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for group norm
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`

     """
     channels: int
     num_head_channels: int = None
+    num_groups: int = 32
     dtype: jnp.dtype = jnp.float32

     def setup(self):
         self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1

         dense = partial(nn.Dense, self.channels, dtype=self.dtype)

-        self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
         self.query, self.key, self.value = dense(), dense(), dense()
         self.proj_attn = dense()

@@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_downsample: bool = True
     dtype: jnp.dtype = jnp.float32

@@ -285,6 +294,7 @@ def setup(self):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -303,9 +313,9 @@ def __call__(self, hidden_states, deterministic=True):
         return hidden_states


-class FlaxUpEncoderBlock2D(nn.Module):
+class FlaxUpDecoderBlock2D(nn.Module):
     r"""
-    Flax Resnet blocks-based Encoder block for diffusion-based VAE.
+    Flax Resnet blocks-based Decoder block for diffusion-based VAE.

     Parameters:
         in_channels (:obj:`int`):
@@ -316,15 +326,18 @@ class FlaxUpEncoderBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
-        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
-            Whether to add downsample layer
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet block group norm
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add upsample layer
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
     in_channels: int
     out_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     add_upsample: bool = True
     dtype: jnp.dtype = jnp.float32

@@ -336,6 +349,7 @@ def setup(self):
                 in_channels=in_channels,
                 out_channels=self.out_channels,
                 dropout=self.dropout,
+                groups=self.resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
             Dropout rate
         num_layers (:obj:`int`, *optional*, defaults to 1):
             Number of Resnet layer block
+        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+            The number of groups to use for the Resnet and Attention block group norm
         attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
             Number of attention heads for each attention block
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +390,20 @@ class FlaxUNetMidBlock2D(nn.Module):
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
+    resnet_groups: int = 32
     attn_num_head_channels: int = 1
     dtype: jnp.dtype = jnp.float32

     def setup(self):
+        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
+
         # there is always at least one resnet
         resnets = [
             FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
         ]
@@ -392,14 +412,18 @@ def setup(self):

         for _ in range(self.num_layers):
             attn_block = FlaxAttentionBlock(
-                channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
+                channels=self.in_channels,
+                num_head_channels=self.attn_num_head_channels,
+                num_groups=resnet_groups,
+                dtype=self.dtype,
             )
             attentions.append(attn_block)

             res_block = FlaxResnetBlock2D(
                 in_channels=self.in_channels,
                 out_channels=self.in_channels,
                 dropout=self.dropout,
+                groups=resnet_groups,
                 dtype=self.dtype,
             )
             resnets.append(res_block)
@@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
             Tuple containing the number of output channels for each block
         layers_per_block (:obj:`int`, *optional*, defaults to `2`):
             Number of Resnet layer for each block
-        norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
+        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             norm num group
         act_fn (:obj:`str`, *optional*, defaults to `silu`):
             Activation function
@@ -483,6 +507,7 @@ def setup(self):
                 in_channels=input_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block,
+                resnet_groups=self.norm_num_groups,
                 add_downsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -491,12 +516,15 @@ def setup(self):

         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )

         # end
         conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             conv_out_channels,
             kernel_size=(3, 3),
@@ -581,7 +609,10 @@ def setup(self):

         # middle
         self.mid_block = FlaxUNetMidBlock2D(
-            in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
+            in_channels=block_out_channels[-1],
+            resnet_groups=self.norm_num_groups,
+            attn_num_head_channels=None,
+            dtype=self.dtype,
         )

         # upsampling
@@ -594,10 +625,11 @@ def setup(self):

             is_final_block = i == len(block_out_channels) - 1

-            up_block = FlaxUpEncoderBlock2D(
+            up_block = FlaxUpDecoderBlock2D(
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
                 num_layers=self.layers_per_block + 1,
+                resnet_groups=self.norm_num_groups,
                 add_upsample=not is_final_block,
                 dtype=self.dtype,
             )
@@ -607,7 +639,7 @@ def setup(self):
         self.up_blocks = up_blocks

         # end
-        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
         self.conv_out = nn.Conv(
             self.out_channels,
             kernel_size=(3, 3),
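
The net effect of this diff is that the group-norm size is no longer hard-coded to 32 inside the VAE blocks. A minimal sketch of the new knob at the block level, assuming only the module path and signatures shown above and that JAX and Flax are installed (the NHWC layout follows Flax's convolution default):

import jax
import jax.numpy as jnp

from diffusers.models.vae_flax import FlaxUNetMidBlock2D

# resnet_groups now reaches the GroupNorm of every resnet and attention layer inside the block.
block = FlaxUNetMidBlock2D(in_channels=32, resnet_groups=8, attn_num_head_channels=None)

hidden_states = jnp.ones((1, 8, 8, 32))  # dummy NHWC activation
variables = block.init(jax.random.PRNGKey(0), hidden_states)
output = block.apply(variables, hidden_states)  # same shape as the input, (1, 8, 8, 32)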

src/diffusers/utils/testing_utils.py

Lines changed: 9 additions & 0 deletions

@@ -14,6 +14,8 @@
 import requests
 from packaging import version

+from .import_utils import is_flax_available
+

 global_rng = random.Random()
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -89,6 +91,13 @@ def slow(test_case):
     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


+def require_flax(test_case):
+    """
+    Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+    """
+    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
 def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
     """
     Args:
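
A quick sketch of how the new decorator is meant to be used in a test module (the class and test names here are illustrative, not part of the commit); the whole class is skipped with "test requires JAX & Flax" when either library is missing:

import unittest

from diffusers.utils.testing_utils import require_flax


@require_flax
class ExampleFlaxTests(unittest.TestCase):
    def test_jax_is_usable(self):
        import jax.numpy as jnp  # only reached when JAX & Flax are installed

        self.assertEqual(jnp.zeros((2, 2)).shape, (2, 2))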

tests/test_modeling_common_flax.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+from diffusers.utils import is_flax_available
+from diffusers.utils.testing_utils import require_flax
+
+
+if is_flax_available():
+    import jax
+
+
+@require_flax
+class FlaxModelTesterMixin:
+    def test_output(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
+        jax.lax.stop_gradient(variables)
+
+        output = model.apply(variables, inputs_dict["sample"])
+
+        if isinstance(output, dict):
+            output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 32)
+
+        model = self.model_class(**init_dict)
+        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
+        jax.lax.stop_gradient(variables)
+
+        output = model.apply(variables, inputs_dict["sample"])
+
+        if isinstance(output, dict):
+            output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

tests/test_models_vae_flax.py

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+import unittest
+
+from diffusers import FlaxAutoencoderKL
+from diffusers.utils import is_flax_available
+from diffusers.utils.testing_utils import require_flax
+
+from .test_modeling_common_flax import FlaxModelTesterMixin
+
+
+if is_flax_available():
+    import jax
+
+
+@require_flax
+class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
+    model_class = FlaxAutoencoderKL
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 3
+        sizes = (32, 32)
+
+        prng_key = jax.random.PRNGKey(0)
+        image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
+
+        return {"sample": image, "prng_key": prng_key}
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "block_out_channels": [32, 64],
+            "in_channels": 3,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            "latent_channels": 4,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
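
The same configuration can also be exercised outside the test harness. A small script mirroring the mixin's test_forward_with_norm_groups above (the kwargs, shapes, and output-unwrapping pattern are taken from that test, not from additional API guarantees):

import jax

from diffusers import FlaxAutoencoderKL

model = FlaxAutoencoderKL(
    in_channels=3,
    out_channels=3,
    block_out_channels=(16, 32),
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    latent_channels=4,
    norm_num_groups=16,  # the parameter this commit trickles down to every block
)

prng_key = jax.random.PRNGKey(0)
sample = jax.random.uniform(prng_key, (4, 3, 32, 32))

variables = model.init(prng_key, sample)
output = model.apply(variables, sample)
if isinstance(output, dict):
    output = output.sample
assert output.shape == sample.shape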
