
Commit 4f496a4

Rename use_static_masks -> optimize_for_inference
1 parent 59cfc49

2 files changed: +5 −5 lines

dalle_pytorch/dalle_pytorch.py

Lines changed: 2 additions & 2 deletions
@@ -344,7 +344,7 @@ def __init__(
         shared_attn_ids = None,
         shared_ff_ids = None,
         share_input_output_emb = False,
-        use_static_masks = False,
+        optimize_for_inference = False,
     ):
         super().__init__()
         assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -392,7 +392,7 @@ def __init__(
             rotary_emb = rotary_emb,
             shared_attn_ids = shared_attn_ids,
             shared_ff_ids = shared_ff_ids,
-            use_static_masks = use_static_masks,
+            optimize_for_inference = optimize_for_inference,
         )

         self.stable = stable
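
For context, the renamed flag is simply forwarded from DALLE.__init__ into the Transformer it constructs. Below is a minimal, hypothetical usage sketch of the new keyword at the call site; every argument value is illustrative rather than taken from this commit, and assumes the standard dalle_pytorch constructors.

# Hypothetical usage sketch -- argument values are illustrative, not from this commit
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512
)

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 2,
    heads = 8,
    attn_types = ('axial_row', 'axial_col'),
    optimize_for_inference = True    # was: use_static_masks = True
)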

dalle_pytorch/transformer.py

Lines changed: 3 additions & 3 deletions
@@ -220,7 +220,7 @@ def __init__(
         rotary_emb = True,
         shared_attn_ids = None,
         shared_ff_ids = None,
-        use_static_masks = False,
+        optimize_for_inference = False,
     ):
         super().__init__()
         layers = nn.ModuleList([])
@@ -245,12 +245,12 @@ def __init__(
             elif attn_type == 'sparse':
                 attn_class = SparseAttention
             elif attn_type == 'axial_row':
-                if use_static_masks:
+                if optimize_for_inference:
                     attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
                 else:
                     attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'axial_col':
-                if use_static_masks:
+                if optimize_for_inference:
                     attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
                 else:
                     attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
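
The branches above reduce to a simple dispatch: when the flag is set, axial attention is built as plain dense Attention with a precomputed static mask instead of the custom sparse kernel, presumably the inference-friendly path the new name refers to. A self-contained sketch of that pattern follows, with stand-in classes rather than dalle_pytorch's real ones.

from functools import partial

class Attention:                              # stand-in for the dense attention class
    def __init__(self, dim, stable = False, static_mask = None):
        self.dim, self.stable, self.static_mask = dim, stable, static_mask

class SparseAxialCausalAttention:             # stand-in for the sparse kernel path
    def __init__(self, dim, seq_len, axis, image_size, stable = False):
        self.dim, self.axis = dim, axis

def pick_attn_class(attn_type, optimize_for_inference, *, seq_len, image_fmap_size, stable, static_mask):
    axis = 0 if attn_type == 'axial_row' else 1    # row attention -> axis 0, column -> axis 1
    if optimize_for_inference:
        # dense attention with a precomputed mask: no custom sparse kernels needed
        return partial(Attention, stable = stable, static_mask = static_mask)
    return partial(SparseAxialCausalAttention, seq_len = seq_len, axis = axis, image_size = image_fmap_size, stable = stable)

# partial pre-binds everything except dim, so every layer is constructed uniformly
attn_class = pick_attn_class('axial_row', True, seq_len = 1280, image_fmap_size = 32, stable = False, static_mask = None)
attn = attn_class(dim = 512)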
