Commit 8c6bf93

Remove flex_attn configs from JobConfig (#1111)
There are conflicts between JobConfig and ModelArgs. Specifically, if ModelArgs fields are exposed through JobConfig, users have to control them via toml files or command-line arguments. For some flex_attn settings this requirement doesn't make sense, since each model flavor already knows the configuration it needs. This PR removes these options from JobConfig and uses the model flavor to control whether flex_attn is enabled.
1 parent 738000f commit 8c6bf93
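
As an illustration of the new pattern, here is a minimal, simplified sketch (not the actual torchtitan code; the dataclass below is a stand-in for the much larger TransformerModelArgs) showing how a flavor now carries its own attention settings, mirroring the debugmodel_flex_attn flavor added in torchtitan/models/llama3/__init__.py below:

from dataclasses import dataclass


@dataclass
class TransformerModelArgs:
    # Simplified stand-in; the real class has many more fields.
    dim: int = 256
    n_layers: int = 6
    n_heads: int = 16
    rope_theta: float = 500000
    # FlexAttention is now a model-level default, off unless a flavor opts in.
    use_flex_attn: bool = False
    attn_mask_type: str = "causal"


# Each flavor bakes in the attention settings it needs, instead of reading
# them from JobConfig (toml / command-line arguments).
llama3_configs = {
    "debugmodel": TransformerModelArgs(),
    "debugmodel_flex_attn": TransformerModelArgs(
        use_flex_attn=True,
        attn_mask_type="block_causal",
    ),
}

# Users select a flavor, e.g. --model.flavor=debugmodel_flex_attn, rather
# than passing --model.use_flex_attn / --model.attn_mask_type directly.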

File tree

9 files changed, +49 -51 lines changed

tests/integration_tests.py

+1 -2

@@ -322,8 +322,7 @@ def build_test_list():
             [
                 "--parallelism.data_parallel_shard_degree=4",
                 "--activation_checkpoint.mode='full'",
-                "--model.use_flex_attn",
-                "--model.attn_mask_type='block_causal'",
+                "--model.flavor=debugmodel_flex_attn",
             ]
         ],
         "FSDP+FLEX_ATTN",

torchtitan/config_manager.py

-19

@@ -193,25 +193,6 @@ def __init__(self):
             choices=["layernorm", "np_layernorm", "rmsnorm"],
             help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm]",
         )
-        self.parser.add_argument(
-            "--model.use_flex_attn",
-            action="store_true",
-            help="""
-                Whether to use Flex Attention.
-                Mixed usage of SDPA and FlexAttention is not upported yet.
-            """,
-        )
-        self.parser.add_argument(
-            "--model.attn_mask_type",
-            type=str,
-            default="causal",
-            choices=["causal", "block_causal"],
-            help="""
-                Specifies the type of bias/mask used for attention. If SDPA is used,
-                only the causal mask is supported by default. If FlexAttention is used,
-                both causal and block_causal masks are supported.
-            """,
-        )
         self.parser.add_argument(
             "--model.tokenizer_path",
             type=str,

torchtitan/experiments/llama4/__init__.py

+14 -2

@@ -29,8 +29,6 @@
         n_layers=6,
         n_heads=16,
         rope_theta=500000,
-        every_n_layers_nope=4,
-        fixed_attn_block_size=256,
     ),
     "17bx16e": TransformerModelArgs(
         dim=5120,
@@ -53,6 +51,16 @@
         rope_theta=500000,
         num_experts=128,
     ),
+    "debugmodel_irope": TransformerModelArgs(
+        dim=256,
+        n_layers=6,
+        n_heads=16,
+        rope_theta=500000,
+        every_n_layers_nope=4,
+        fixed_attn_block_size=256,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+    ),
     "17bx16e_irope": TransformerModelArgs(
         dim=5120,
         n_layers=48,
@@ -64,6 +72,8 @@
         num_experts=16,
         interleave_moe_layer_step=1,
         every_n_layers_nope=4,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
     ),
     "17bx128e_irope": TransformerModelArgs(
         dim=5120,
@@ -75,6 +85,8 @@
         rope_theta=500000,
         num_experts=128,
         every_n_layers_nope=4,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
     ),
 }

torchtitan/experiments/llama4/infra/parallelize_llama.py

-8

@@ -65,14 +65,6 @@ def parallelize_llama(
         apply_moe_tp(model, world_mesh["tp"])

     if job_config.activation_checkpoint.mode != "none":
-        if (
-            job_config.activation_checkpoint.mode == "selective"
-            and job_config.model.use_flex_attn
-        ):
-            raise ValueError(
-                "FlexAttention is not compatible with selective AC yet. "
-                "See https://github.com/pytorch/pytorch/issues/147879"
-            )
         apply_ac(model, job_config.activation_checkpoint)

     # turn on per-TransformerBlock compile after AC wrapping and before FSDP

torchtitan/experiments/llama4/model/args.py

+14 -3

@@ -34,8 +34,8 @@ class TransformerModelArgs(BaseModelArgs):
     depth_init: bool = True
     norm_type: str = "rmsnorm"

-    use_flex_attn: bool = True
-    attn_mask_type: str = "block_causal"
+    use_flex_attn: bool = False
+    attn_mask_type: str = "causal"
     eos_id: int = 0
     # iRoPE settings
     # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is
@@ -62,13 +62,24 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
         self.norm_type = job_config.model.norm_type
         self.vocab_size = tokenizer.n_words
         self.max_seq_len = job_config.training.seq_len
-        self.use_flex_attn = job_config.model.use_flex_attn
         if self.use_grouped_mm and not has_cuda_capability(9, 0):
             logger.warning(
                 "Failed to use grouped mm, which is only supported on SM90 or later",
             )
             self.use_grouped_mm = False

+        if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn:
+            raise ValueError(
+                "FlexAttention is not compatible with selective AC yet. "
+                "See https://github.com/pytorch/pytorch/issues/147879"
+            )
+
+        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
+            raise ValueError(
+                "FlexAttention is not compatible with CP yet. "
+                "We are still working on this."
+            )
+
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:

torchtitan/experiments/llama4/train_configs/debug_model.toml

-2

@@ -25,8 +25,6 @@ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./tests/assets/test_tiktoken.model"
 # converters = "float8"
-use_flex_attn = false
-attn_mask_type = "causal" # causal / block_causal

 [optimizer]
 # TODO: currently grouped mm in MoE doesn't work with AdamW, need to investigate

torchtitan/models/llama3/__init__.py

+8

@@ -30,6 +30,14 @@
     "debugmodel": TransformerModelArgs(
         dim=256, n_layers=6, n_heads=16, rope_theta=500000
     ),
+    "debugmodel_flex_attn": TransformerModelArgs(
+        dim=256,
+        n_layers=6,
+        n_heads=16,
+        rope_theta=500000,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+    ),
     "8B": TransformerModelArgs(
         dim=4096,
         n_layers=32,

torchtitan/models/llama3/model.py

+12 -2

@@ -47,8 +47,18 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
         self.norm_type = job_config.model.norm_type
         self.vocab_size = tokenizer.n_words
         self.max_seq_len = job_config.training.seq_len
-        self.use_flex_attn = job_config.model.use_flex_attn
-        self.attn_mask_type = job_config.model.attn_mask_type
+
+        if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn:
+            raise ValueError(
+                "FlexAttention is not compatible with selective AC yet. "
+                "See https://github.com/pytorch/pytorch/issues/147879"
+            )
+
+        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
+            raise ValueError(
+                "FlexAttention is not compatible with CP yet. "
+                "We are still working on this."
+            )

     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())

torchtitan/models/llama3/parallelize_llama.py

-13

@@ -72,19 +72,6 @@ def parallelize_llama(
             enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
         )

-    if job_config.model.use_flex_attn:
-        if job_config.activation_checkpoint.mode == "selective":
-            raise ValueError(
-                "FlexAttention is not compatible with selective AC yet. "
-                "See https://github.com/pytorch/pytorch/issues/147879"
-            )
-
-        if parallel_dims.cp_enabled:
-            raise ValueError(
-                "FlexAttention is not compatible with CP yet. "
-                "We are still working on this."
-            )
-
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
