Skip to content

Commit a28b4d4

Browse files
committed
fix(loss): 修正DOSLoss中的损失计算方式
将平方损失改为直接使用L1损失,因为MSE损失等价于自加权的L1损失
1 parent 1792883 commit a28b4d4

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

dptb/nnops/loss.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -462,10 +462,11 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
462462
dosdft, ddosdft, d2dosdft = DOSLoss.calc_dos(
463463
eigvaldft, wk, emin, emax,
464464
de=self.de, sigma=self.degauss, with_derivatives=True)
465-
# the loss is the MSE between two DOS (self-weighted)
466-
loss += self.loss(dostb, dosdft )**2 + \
467-
self.loss(ddostb, ddosdft )**2 + \
468-
self.loss(d2dostb, d2dosdft)**2
465+
# the mse loss is equivalent with the L1Loss with self-weighted, so it is good.
466+
# loss += |dostb - dosdft|^2 + |ddostb - ddosdft|^2 + |d2dostb - d2dosdft|^2
467+
loss += self.loss(dostb, dosdft ) + \
468+
self.loss(ddostb, ddosdft ) + \
469+
self.loss(d2dostb, d2dosdft)
469470
return loss
470471

471472
# @Loss.register("hamil")

0 commit comments

Comments
 (0)