Commit 2ba3100

Perf: use F.linear for MLP (#4513)
It brings a <1% speedup.

Summary by CodeRabbit:
- Refactor
  - Simplified the linear-transformation implementation in the MLP layer
  - Improved code readability and efficiency in the matrix operations
1 parent fdf8049 commit 2ba3100

File tree: 1 file changed (+5, -8 lines)


deepmd/pt/model/network/mlp.py

Lines changed: 5 additions & 8 deletions
@@ -8,6 +8,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from deepmd.pt.utils import (
     env,
@@ -202,18 +203,14 @@ def forward(
         ori_prec = xx.dtype
         if not env.DP_DTYPE_PROMOTION_STRICT:
             xx = xx.to(self.prec)
-        yy = (
-            torch.matmul(xx, self.matrix) + self.bias
-            if self.bias is not None
-            else torch.matmul(xx, self.matrix)
-        )
-        yy = self.activate(yy).clone()
+        yy = F.linear(xx, self.matrix.t(), self.bias)
+        yy = self.activate(yy)
         yy = yy * self.idt if self.idt is not None else yy
         if self.resnet:
             if xx.shape[-1] == yy.shape[-1]:
-                yy += xx
+                yy = yy + xx
             elif 2 * xx.shape[-1] == yy.shape[-1]:
-                yy += torch.concat([xx, xx], dim=-1)
+                yy = yy + torch.concat([xx, xx], dim=-1)
             else:
                 yy = yy
         if not env.DP_DTYPE_PROMOTION_STRICT:
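
The replacement works because `F.linear(x, W, b)` computes `x @ W.t() + b`, so passing the transposed weight reproduces `x @ W + b` in a single fused call (and handles the `bias is None` case itself). A minimal sketch verifying the equivalence, assuming the layer stores its weight input-major as `(in_features, out_features)` the way `self.matrix` is used above (the tensor names here are illustrative, not from the patch):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3)   # batch of inputs
W = torch.randn(3, 5)   # weight stored input-major, like self.matrix
b = torch.randn(5)      # bias, like self.bias

# Old formulation: explicit matmul plus bias.
y_matmul = torch.matmul(x, W) + b

# New formulation: F.linear expects (out_features, in_features),
# hence the transpose; bias may also be None here.
y_linear = F.linear(x, W.t(), b)

print("max abs diff:", (y_matmul - y_linear).abs().max().item())
```

`F.linear` also accepts `bias=None`, which is why the old four-line conditional collapses to one call.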
