
Commit 5078e92

deprecate norm.py and norm_type from JobConfig (#1115)
`norm.py`, and `norm_type` in `JobConfig` and the llama `ModelArgs`, were introduced when `nn.RMSNorm` was not yet available. Now that `nn.RMSNorm` exists, they are no longer needed, so let's remove them, following #1111.
1 parent 8c6bf93 commit 5078e92

File tree: 15 files changed (+29, -118 lines)
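
For context before the per-file diffs: the change is mechanical. Wherever the removed `torchtitan.models.norms.build_norm(...)` helper produced an RMSNorm, the code now constructs PyTorch's built-in `nn.RMSNorm` directly. A minimal sketch of the swap, assuming a PyTorch release that ships `nn.RMSNorm` (the `dim`/`eps` values below are illustrative, not taken from a torchtitan config):

```python
import torch
from torch import nn

dim, eps = 256, 1e-5

# Old pattern (removed): attention_norm = build_norm("rmsnorm", dim=dim, eps=eps)
# New pattern (this commit): use the built-in module directly.
attention_norm = nn.RMSNorm(dim, eps=eps)

x = torch.randn(2, 8, dim)   # (batch, seq, dim)
y = attention_norm(x)        # normalizes over the last dimension
assert y.shape == x.shape
```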

scripts/estimate/estimation.py

Lines changed: 9 additions & 19 deletions
@@ -34,10 +34,6 @@ def estimate_memory(job_config: JobConfig):
     # Get the world size
     world_size = int(os.environ["WORLD_SIZE"])

-    if job_config.model.norm_type == "compiled_rmsnorm":
-        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
-        job_config.model.norm_type = "rmsnorm"
-
     if job_config.training.compile or job_config.parallelism.enable_compiled_autograd:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False
@@ -91,25 +87,19 @@ def estimate_memory(job_config: JobConfig):

     # build model (using meta init)
     model_cls = train_spec.cls
-    model_config = train_spec.config[job_config.model.flavor]
-    # set the model configs from training inputs:
-    # 1. norm type to decide which norm layer to use
-    # 2. vocab size from tokenizer
-    # 3. max_seq_len base on inputs
-    model_config.norm_type = job_config.model.norm_type
-    model_config.vocab_size = tokenizer.n_words
-    model_config.max_seq_len = job_config.training.seq_len
+    model_args = train_spec.config[job_config.model.flavor]
+    model_args.update_from_config(job_config, tokenizer)

     with (
         FakeTensorMode()
         if not job_config.memory_estimation.disable_fake_mode
         else contextlib.nullcontext()
     ):
         logger.info(
-            f"Building {train_spec.name} {job_config.model.flavor} with {model_config}"
+            f"Building {train_spec.name} {job_config.model.flavor} with {model_args}"
         )
         with torch.device("meta"):
-            model = model_cls.from_model_args(model_config)
+            model = model_cls.from_model_args(model_args)

         # Build the collection of model converters. No-op if `model.converters` empty
         model_converters = build_model_converters(job_config, parallel_dims)
@@ -134,19 +124,19 @@ def estimate_memory(job_config: JobConfig):
             lambda *args, **kwargs: model_converters.post_optimizer_hook(model)
         )

-        logger.info(f"Vocab size: {model_config.vocab_size}")
+        logger.info(f"Vocab size: {model_args.vocab_size}")
         # Create a dummy batch instead of loading from a dataset
         batch = (
             torch.randint(
                 0,
-                model_config.vocab_size,
-                (job_config.training.batch_size, model_config.max_seq_len),
+                model_args.vocab_size,
+                (job_config.training.batch_size, model_args.max_seq_len),
                 device="cuda",
             ),
             torch.randint(
                 0,
-                model_config.vocab_size,
-                (job_config.training.batch_size, model_config.max_seq_len),
+                model_args.vocab_size,
+                (job_config.training.batch_size, model_args.max_seq_len),
                 device="cuda",
             ),
         )
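
Aside from the `model_config` → `model_args` rename, the hunks above keep the meta-device build that makes estimation cheap: parameters are created without real storage, so the model can be constructed and inspected before any CUDA allocation. A self-contained sketch of that idea, using a toy module rather than the torchtitan model:

```python
import torch
from torch import nn

# Toy stand-in for model_cls.from_model_args(model_args); shapes are made up.
with torch.device("meta"):
    toy_model = nn.Sequential(nn.Linear(4096, 11008), nn.Linear(11008, 4096))

# Parameter counts work on meta tensors even though no real memory is allocated.
n_params = sum(p.numel() for p in toy_model.parameters())
print(f"{n_params:,} params, device = {next(toy_model.parameters()).device}")  # meta
```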

scripts/generate/test_generate.py

Lines changed: 10 additions & 11 deletions
@@ -85,9 +85,9 @@ def test_generate(
     color = utils.Color

     # Load configuration from toml file
-    config = JobConfig()
-    config.parse_args([f"--job.config_file={config_path}"])
-    config._validate_config()
+    job_config = JobConfig()
+    job_config.parse_args([f"--job.config_file={config_path}"])
+    job_config._validate_config()

     if len(args.prompt) == 0:
         logger.warning(
@@ -100,27 +100,26 @@ def test_generate(
     device_module.set_device(device)
     device_memory_monitor = build_device_memory_monitor()

-    train_spec = get_train_spec(config.model.name)
+    train_spec = get_train_spec(job_config.model.name)

     logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")

     # Tokenizer setup
-    tokenizer = train_spec.build_tokenizer_fn(config)
-    model_config = train_spec.config[config.model.flavor]
-    model_config.norm_type = config.model.norm_type
-    model_config.max_seq_len = config.training.seq_len
-    model_config.vocab_size = tokenizer.n_words
+    tokenizer = train_spec.build_tokenizer_fn(job_config)

     model_cls = train_spec.cls
+    model_args = train_spec.config[job_config.model.flavor]
+    model_args.update_from_config(job_config, tokenizer)
+
     init_device = "meta" if world_size > 1 else device
     with torch.device(init_device):
         logger.info(f"Init model on init_device: {init_device}")
-        model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_args)

     world_mesh = None
     # Init distributed env
     if world_size > 1:
-        dist_utils.init_distributed(config)
+        dist_utils.init_distributed(job_config)
         parallel_dims = ParallelDims(
             dp_replicate=1,
             dp_shard=-1,
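
The change in this file is essentially the rename from `config` to `job_config` plus the consolidated `update_from_config` call; the config-parsing flow itself is unchanged. A hedged usage sketch of that flow, assuming torchtitan at this commit is importable and substituting a hypothetical TOML path:

```python
from torchtitan.config_manager import JobConfig

# Hypothetical path; point this at a real TOML under your train_configs/.
config_path = "./train_configs/debug_model.toml"

job_config = JobConfig()
job_config.parse_args([f"--job.config_file={config_path}"])
job_config._validate_config()  # same (private) validation call used in test_generate.py

print(job_config.model.name, job_config.model.flavor)
```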

torchtitan/config_manager.py

Lines changed: 0 additions & 7 deletions
@@ -186,13 +186,6 @@ def __init__(self):
            default="debugmodel",
            help="Which model config to train",
        )
-        self.parser.add_argument(
-            "--model.norm_type",
-            type=str,
-            default="rmsnorm",
-            choices=["layernorm", "np_layernorm", "rmsnorm"],
-            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm]",
-        )
        self.parser.add_argument(
            "--model.tokenizer_path",
            type=str,

torchtitan/experiments/llama4/model/args.py

Lines changed: 0 additions & 2 deletions
@@ -32,7 +32,6 @@ class TransformerModelArgs(BaseModelArgs):
     # If `True`, then each transformer block init uses its layer ID, and if
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
-    norm_type: str = "rmsnorm"

     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
@@ -59,7 +58,6 @@ class TransformerModelArgs(BaseModelArgs):
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation

     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
-        self.norm_type = job_config.model.norm_type
         self.vocab_size = tokenizer.n_words
         self.max_seq_len = job_config.training.seq_len
         if self.use_grouped_mm and not has_cuda_capability(9, 0):
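
With `norm_type` gone, `update_from_config` only copies tokenizer- and training-derived fields onto the args object, which is what the call sites in `estimation.py` and `test_generate.py` now rely on. A runnable toy version of that pattern, using simplified, hypothetical stand-ins for `JobConfig` and the tokenizer:

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class ToyModelArgs:
    vocab_size: int = -1
    max_seq_len: int = 2048

    def update_from_config(self, job_config, tokenizer) -> None:
        # Mirrors the fields the real update_from_config copies in this diff.
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len


tokenizer = SimpleNamespace(n_words=32000)
job_config = SimpleNamespace(training=SimpleNamespace(seq_len=4096))

model_args = ToyModelArgs()
model_args.update_from_config(job_config, tokenizer)
assert (model_args.vocab_size, model_args.max_seq_len) == (32000, 4096)
```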

torchtitan/experiments/llama4/model/model.py

Lines changed: 5 additions & 17 deletions
@@ -10,7 +10,6 @@
 from torch import nn

 from torchtitan.models.attention import build_attention, init_attention_mask
-from torchtitan.models.norms import build_norm
 from torchtitan.protocols.train_spec import ModelProtocol

 from .args import TransformerModelArgs
@@ -311,20 +310,13 @@ def __init__(
             ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )

-        self.layer_id = layer_id
-        self.num_layers = model_args.n_layers
-
-        self.attention_norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
-        )
-        self.ffn_norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
-        )
+        self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)

         if model_args.depth_init:
-            self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
+            self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
         else:
-            self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
+            self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5

     def forward(
         self,
@@ -399,11 +391,7 @@ def __init__(self, model_args: TransformerModelArgs):
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.n_layers):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
-
-        self.norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
-        )
-
+        self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
         self.init_weights()
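
The `weight_init_std` change above is a pure refactor: `layer_id` and `model_args.n_layers` are read directly instead of being stashed on `self`, and the resulting values are identical. A small standalone check of the two branches (the layer counts below are illustrative):

```python
def weight_init_std(layer_id: int, n_layers: int, depth_init: bool) -> float:
    # Mirrors the computation kept in TransformerBlock.__init__ above.
    if depth_init:
        return 0.02 / (2 * (layer_id + 1)) ** 0.5
    return 0.02 / (2 * n_layers) ** 0.5


# depth_init=True scales per layer; depth_init=False uses the total layer count.
print(weight_init_std(0, 16, True), weight_init_std(15, 16, True))
print(weight_init_std(0, 16, False))
```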

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ enable_wandb = false
 [model]
 name = "llama4"
 flavor = "debugmodel"
-norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./tests/assets/test_tiktoken.model"
 # converters = "float8"

torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@ save_tb_folder = "tb"
 [model]
 name = "llama4"
 flavor = "17bx128e"
-norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
 tokenizer_path = "./assets/tokenizer/tokenizer.model"
 # converters = "float8"

torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@ save_tb_folder = "tb"
 [model]
 name = "llama4"
 flavor = "17bx16e"
-norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
 tokenizer_path = "./assets/tokenizer/tokenizer.model"
 # converters = "float8"

torchtitan/experiments/multimodal/model.py

Lines changed: 0 additions & 1 deletion
@@ -56,7 +56,6 @@ class ModelArgs:
     # If `True`, then each transformer block init uses its layer ID, and if
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
-    norm_type: str = "rmsnorm"


 class Fp32LayerNorm(nn.LayerNorm):

torchtitan/models/llama3/model.py

Lines changed: 5 additions & 19 deletions
@@ -17,7 +17,6 @@
 from torchtitan.components.tokenizer import Tokenizer
 from torchtitan.config_manager import JobConfig
 from torchtitan.models.attention import build_attention, init_attention_mask
-from torchtitan.models.norms import build_norm
 from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol


@@ -37,14 +36,12 @@ class TransformerModelArgs(BaseModelArgs):
     # If `True`, then each transformer block init uses its layer ID, and if
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
-    norm_type: str = "rmsnorm"

     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
     eos_id: int = 0

     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
-        self.norm_type = job_config.model.norm_type
         self.vocab_size = tokenizer.n_words
         self.max_seq_len = job_config.training.seq_len

@@ -341,20 +338,13 @@ def __init__(self, layer_id: int, model_args: TransformerModelArgs):
             multiple_of=model_args.multiple_of,
             ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
-        self.layer_id = layer_id
-        self.num_layers = model_args.n_layers
-
-        self.attention_norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
-        )
-        self.ffn_norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
-        )
+        self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)

         if model_args.depth_init:
-            self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
+            self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
         else:
-            self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
+            self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5

     def forward(
         self,
@@ -423,11 +413,7 @@ def __init__(self, model_args: TransformerModelArgs):
         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.n_layers):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
-
-        self.norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
-        )
-
+        self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
         self.init_weights()
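
As a sanity check on the swap to `nn.RMSNorm`: it matches the standard RMSNorm formula. The manual expression below is the textbook definition used for comparison, not code taken from the removed `norms.py`:

```python
import torch
from torch import nn

dim, eps = 64, 1e-6
norm = nn.RMSNorm(dim, eps=eps)

x = torch.randn(4, dim)
# Standard RMSNorm: x / sqrt(mean(x^2) + eps), scaled by the learned weight.
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * norm.weight
torch.testing.assert_close(norm(x), manual)
print("nn.RMSNorm matches the manual RMSNorm formula")
```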
