@@ -21,9 +21,7 @@ def forward(self, x):
 
 
 class LNLinearActivationModel(nn.Module):
-    def __init__(
-        self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid", device=None
-    ):
+    def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid"):
         super().__init__()
 
         activation = activation.lower()
@@ -41,7 +39,7 @@ def __init__(
             raise ValueError(f"Unsupported activation: {activation}")
 
         self.ln = nn.LayerNorm(fc_dim1, elementwise_affine=False)
-        self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype=dtype, device=device)
+        self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype=dtype)
         self.activation = activation_map[activation]
 
     def forward(self, x):
@@ -50,6 +48,20 @@ def forward(self, x):
         return self.activation(x)
 
 
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
 class TransformerBlock(torch.nn.Module):
     def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
         super().__init__()
@@ -72,8 +84,8 @@ def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
         )
 
         # Layer norms
-        self.norm1 = nn.RMSNorm(hidden_dim, dtype=dtype)
-        self.norm2 = nn.RMSNorm(hidden_dim, dtype=dtype)
+        self.norm1 = RMSNorm(hidden_dim).to(dtype)
+        self.norm2 = RMSNorm(hidden_dim).to(dtype)
 
         # Activation
         self.activation = torch.nn.GELU()
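
Not part of the commit, just for review context: a minimal sanity-check sketch comparing the hand-rolled RMSNorm added here against the built-in nn.RMSNorm it replaces in TransformerBlock. The class body is copied from the diff above; the shapes, eps value, and tolerances are illustrative assumptions.

# Sanity-check sketch (not part of the diff); shapes, eps, and tolerances are illustrative.
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    # Copied from the class added in this diff.
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


dim = 64
x = torch.randn(8, dim)

custom = RMSNorm(dim, eps=1e-5)
builtin = nn.RMSNorm(dim, eps=1e-5)  # weight is also initialized to ones

# Both compute x * rsqrt(mean(x^2) + eps) * weight, so outputs should agree
# up to floating-point rounding in float32.
torch.testing.assert_close(custom(x), builtin(x), rtol=1e-5, atol=1e-5)

Note that the custom module upcasts to float32 inside _norm before casting back with type_as, so the reduction stays in full precision even when TransformerBlock casts the module to bfloat16 via .to(dtype), which only converts the weight parameter.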