Commit 000fa82

[Chore] remove class assignments for linear and conv. (#7553)
* remove class assignments for linear and conv.
* fix: self.nn
1 parent 5d83f50 · commit 000fa82

10 files changed, 38 insertions(+), 61 deletions(-)
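Every file in this commit changes in the same way: a local alias such as linear_cls = nn.Linear or conv_cls = nn.Conv2d was bound in __init__ and then called to build layers, and the alias is now dropped in favor of calling nn.Linear / nn.Conv2d directly (plus removing the stray self.linear_cls attribute used in the attention processor). A minimal before/after sketch of the pattern follows; BlockBefore and BlockAfter are illustrative names, not classes from the diff, and only torch is assumed.

import torch.nn as nn

# Before: a module-local alias that always resolved to the same torch class.
class BlockBefore(nn.Module):
    def __init__(self, dim: int, inner_dim: int):
        super().__init__()
        linear_cls = nn.Linear
        self.proj = linear_cls(dim, inner_dim)

# After: construct the layer with the torch class directly.
class BlockAfter(nn.Module):
    def __init__(self, dim: int, inner_dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, inner_dim)

Since the alias was always assigned the same class immediately before use, the two forms build identical modules; the change is purely a readability cleanup.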

src/diffusers/models/attention.py

+1 -2

@@ -634,7 +634,6 @@ def __init__(
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = nn.Linear

         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -651,7 +650,7 @@ def __init__(
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))

src/diffusers/models/attention_processor.py

+8 -11

@@ -181,25 +181,22 @@ def __init__(
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )

-        linear_cls = nn.Linear
-
-        self.linear_cls = linear_cls
-        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)

         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None

         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)

         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

         # set attention processor
@@ -706,7 +703,7 @@ def fuse_projections(self, fuse=True):
             out_features = concatenated_weights.shape[0]

             # create a new single projection layer and copy over the weights.
-            self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_qkv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
@@ -717,7 +714,7 @@ def fuse_projections(self, fuse=True):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]

-            self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])

src/diffusers/models/downsampling.py

+1 -2

@@ -102,7 +102,6 @@ def __init__(
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d

         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ def __init__(
             raise ValueError(f"unknown norm_type: {norm_type}")

         if use_conv:
-            conv = conv_cls(
+            conv = nn.Conv2d(
                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
             )
         else:

src/diffusers/models/embeddings.py

+2 -3

@@ -199,9 +199,8 @@ def __init__(
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear

-        self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +213,7 @@ def __init__(
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

         if post_act_fn is None:
             self.post_act = None

src/diffusers/models/resnet.py

+8 -13

@@ -101,8 +101,6 @@ def __init__(
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm

-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups

@@ -113,7 +111,7 @@ def __init__(
         else:
             raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if self.time_embedding_norm == "ada_group":  # ada_group
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
@@ -125,7 +123,7 @@ def __init__(
         self.dropout = torch.nn.Dropout(dropout)

         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

         self.nonlinearity = get_activation(non_linearity)

@@ -139,7 +137,7 @@ def __init__(

         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -263,21 +261,18 @@ def __init__(
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act

-        linear_cls = nn.Linear
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups

         self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if temb_channels is not None:
             if self.time_embedding_norm == "default":
-                self.time_emb_proj = linear_cls(temb_channels, out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
         else:
@@ -287,7 +282,7 @@ def __init__(

         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

         self.nonlinearity = get_activation(non_linearity)

@@ -313,7 +308,7 @@ def __init__(

         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,

src/diffusers/models/transformers/transformer_2d.py

+4 -7

@@ -117,9 +117,6 @@ def __init__(
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim

-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
@@ -159,9 +156,9 @@ def __init__(

             self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
             if use_linear_projection:
-                self.proj_in = linear_cls(in_channels, inner_dim)
+                self.proj_in = nn.Linear(in_channels, inner_dim)
             else:
-                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
             assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -222,9 +219,9 @@ def __init__(
         if self.is_input_continuous:
             # TODO: should use out_channels for continuous projections
             if use_linear_projection:
-                self.proj_out = linear_cls(inner_dim, in_channels)
+                self.proj_out = nn.Linear(inner_dim, in_channels)
             else:
-                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

src/diffusers/models/unets/unet_stable_cascade.py

+4 -5

@@ -41,11 +41,11 @@ def forward(self, x):
 class SDCascadeTimestepBlock(nn.Module):
     def __init__(self, c, c_timestep, conds=[]):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
         self.conds = conds
         for cname in conds:
-            setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
+            setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))

     def forward(self, x, t):
         t = t.chunk(len(self.conds) + 1, dim=1)
@@ -94,12 +94,11 @@ def forward(self, x):
 class SDCascadeAttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
-        linear_cls = nn.Linear

         self.self_attn = self_attn
         self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))

     def forward(self, x, kv):
         kv = self.kv_mapper(kv)

src/diffusers/models/upsampling.py

+1 -2

@@ -110,7 +110,6 @@ def __init__(
         self.use_conv_transpose = use_conv_transpose
         self.name = name
         self.interpolate = interpolate
-        conv_cls = nn.Conv2d

         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,7 @@ def __init__(
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":

src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py

+5 -10

@@ -17,8 +17,8 @@ def forward(self, x):
 class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)

     def forward(self, x, t):
         a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -29,13 +29,10 @@ class ResBlock(nn.Module):
     def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()

-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
-        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
+            nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
         )

     def forward(self, x, x_skip=None):
@@ -64,12 +61,10 @@ class AttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()

-        linear_cls = nn.Linear
-
         self.self_attn = self_attn
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))

     def forward(self, x, kv):
         kv = self.kv_mapper(kv)

src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py

+4 -6

@@ -40,15 +40,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
    @register_to_config
    def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
        super().__init__()
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear

        self.c_r = c_r
-        self.projection = conv_cls(c_in, c, kernel_size=1)
+        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
        self.cond_mapper = nn.Sequential(
-            linear_cls(c_cond, c),
+            nn.Linear(c_cond, c),
            nn.LeakyReLU(0.2),
-            linear_cls(c, c),
+            nn.Linear(c, c),
        )

        self.blocks = nn.ModuleList()
@@ -58,7 +56,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro
            self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
        self.out = nn.Sequential(
            WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            conv_cls(c, c_in * 2, kernel_size=1),
+            nn.Conv2d(c, c_in * 2, kernel_size=1),
        )

        self.gradient_checkpointing = False
