Skip to content

Commit

Permalink
Fixed torch.nn.Module cannot be assigned as str in SchNet (pyg-…
Browse files Browse the repository at this point in the history
  • Loading branch information
EdisonLeeeee authored Dec 5, 2023
1 parent 44c9c12 commit 1e92ba2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch_geometric/nn/models/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def from_qm9_pretrained(
net.lin2.bias = state.output_modules[0].out_net[1].out_net[1].bias

mean = state.output_modules[0].atom_pool.average
net.readout = 'mean' if mean is True else 'add'
net.readout = aggr_resolver('mean' if mean is True else 'add')

dipole = state.output_modules[0].__class__.__name__ == 'DipoleMoment'
net.dipole = dipole
Expand Down

0 comments on commit 1e92ba2

Please sign in to comment.