@@ -21,9 +21,7 @@ def forward(self, x):
 
 
 class LNLinearActivationModel(nn.Module):
-    def __init__(
-        self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid", device=None
-    ):
+    def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid"):
         super().__init__()
 
         activation = activation.lower()
@@ -41,7 +39,7 @@ def __init__(
             raise ValueError(f"Unsupported activation: {activation}")
 
         self.ln = nn.LayerNorm(fc_dim1, elementwise_affine=False)
-        self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype=dtype, device=device)
+        self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype=dtype)
         self.activation = activation_map[activation]
 
     def forward(self, x):
@@ -50,6 +48,20 @@ def forward(self, x):
         return self.activation(x)
 
 
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
 class TransformerBlock(torch.nn.Module):
     def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
         super().__init__()
@@ -72,8 +84,8 @@ def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
         )
 
         # Layer norms
-        self.norm1 = nn.RMSNorm(hidden_dim, dtype=dtype)
-        self.norm2 = nn.RMSNorm(hidden_dim, dtype=dtype)
+        self.norm1 = RMSNorm(hidden_dim).to(dtype)
+        self.norm2 = RMSNorm(hidden_dim).to(dtype)
 
         # Activation
         self.activation = torch.nn.GELU()
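For reference, here is a minimal, self-contained sanity check of the RMSNorm class introduced in this diff. The class body is copied verbatim from the hunk above; the hidden_dim value, input shape, and the bfloat16 cast are illustrative assumptions, mirroring how TransformerBlock now constructs RMSNorm(hidden_dim).to(dtype) in place of nn.RMSNorm(hidden_dim, dtype=dtype):

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """RMSNorm as added in this commit: the normalization statistics are
    computed in float32 for numerical stability, then cast back to the
    input dtype before applying the learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), reduced over the last dimension
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


# Illustrative usage (hidden_dim and the input shape are assumptions):
hidden_dim = 64
norm = RMSNorm(hidden_dim).to(torch.bfloat16)  # .to(dtype) casts the weight, as in TransformerBlock
x = torch.randn(2, 16, hidden_dim, dtype=torch.bfloat16)
y = norm(x)
assert y.shape == x.shape and y.dtype == torch.bfloat16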