Skip to content

Commit ba4348d

Browse files
authored
[Tests] Improve transformers model test suite coverage - Lumina (#8987)
* Added test suite for lumina * Fixed failing tests * Improved code quality * Added function docstrings * Improved formatting
1 parent d25eb5d commit ba4348d

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import torch
19+
20+
from diffusers import LuminaNextDiT2DModel
21+
from diffusers.utils.testing_utils import (
22+
enable_full_determinism,
23+
torch_device,
24+
)
25+
26+
from ..test_modeling_common import ModelTesterMixin
27+
28+
29+
enable_full_determinism()
30+
31+
32+
class LuminaNextDiT2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
33+
model_class = LuminaNextDiT2DModel
34+
main_input_name = "hidden_states"
35+
36+
@property
37+
def dummy_input(self):
38+
"""
39+
Args:
40+
None
41+
Returns:
42+
Dict: Dictionary of dummy input tensors
43+
"""
44+
batch_size = 2 # N
45+
num_channels = 4 # C
46+
height = width = 16 # H, W
47+
embedding_dim = 32 # D
48+
sequence_length = 16 # L
49+
50+
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
51+
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
52+
timestep = torch.rand(size=(batch_size,)).to(torch_device)
53+
encoder_mask = torch.randn(size=(batch_size, sequence_length)).to(torch_device)
54+
image_rotary_emb = torch.randn((384, 384, 4)).to(torch_device)
55+
56+
return {
57+
"hidden_states": hidden_states,
58+
"encoder_hidden_states": encoder_hidden_states,
59+
"timestep": timestep,
60+
"encoder_mask": encoder_mask,
61+
"image_rotary_emb": image_rotary_emb,
62+
"cross_attention_kwargs": {},
63+
}
64+
65+
@property
66+
def input_shape(self):
67+
"""
68+
Args:
69+
None
70+
Returns:
71+
Tuple: (int, int, int)
72+
"""
73+
return (4, 16, 16)
74+
75+
@property
76+
def output_shape(self):
77+
"""
78+
Args:
79+
None
80+
Returns:
81+
Tuple: (int, int, int)
82+
"""
83+
return (4, 16, 16)
84+
85+
def prepare_init_args_and_inputs_for_common(self):
86+
"""
87+
Args:
88+
None
89+
90+
Returns:
91+
Tuple: (Dict, Dict)
92+
"""
93+
init_dict = {
94+
"sample_size": 16,
95+
"patch_size": 2,
96+
"in_channels": 4,
97+
"hidden_size": 24,
98+
"num_layers": 2,
99+
"num_attention_heads": 3,
100+
"num_kv_heads": 1,
101+
"multiple_of": 16,
102+
"ffn_dim_multiplier": None,
103+
"norm_eps": 1e-5,
104+
"learn_sigma": False,
105+
"qk_norm": True,
106+
"cross_attention_dim": 32,
107+
"scaling_factor": 1.0,
108+
}
109+
110+
inputs_dict = self.dummy_input
111+
return init_dict, inputs_dict

0 commit comments

Comments
 (0)