Skip to content

Commit

Permalink
Rename orbnet to ofmnet
Browse files Browse the repository at this point in the history
  • Loading branch information
Irlirion committed Feb 11, 2022
1 parent 93c559e commit 8eda49b
Show file tree
Hide file tree
Showing 12 changed files with 25 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ includes:
- src/configs/base.yml

model:
name: orbnet_native
name: ofmnet_native
emb_size_atom: 128
emb_size_edge: 64
cutoff: 6.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dataset:
- src: is2res_train_val_test_lmdbs/data/is2re/all/descriptors/val_id/data.mdb

model:
name: orbnet_native_desc
name: ofmnet_native_desc
emb_size_atom: 128
emb_size_edge: 64
cutoff: 6.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dataset:
- src: is2res_train_val_test_lmdbs/data/is2re/all/descriptors/val_id/data.mdb

model:
name: orbnet_native_desc
name: ofmnet_native_desc
emb_size_atom: 128
emb_size_edge: 64
cutoff: 6.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dataset:
- src: is2res_train_val_test_lmdbs/data/is2re/all/descriptors/val_id/data.mdb

model:
name: orbnet_native_surf
name: ofmnet_native_surf
emb_size_atom: 128
emb_size_edge: 64
cutoff: 6.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dataset:
- src: is2res_train_val_test_lmdbs/data/is2re/all/descriptors/val_idh/data.mdb

model:
name: orbnet_native_desc
name: ofmnet_native_desc
emb_size_atom: 128
emb_size_edge: 64
cutoff: 6.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dataset:
- src: is2res_train_val_test_lmdbs/data/is2re/all/descriptors/val_id/data.mdb

model:
name: orbnet_native_surf_ofms
name: ofmnet_native_surf_ofms
emb_size_atom: 128
emb_size_edge: 64
cutoff: 6.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dataset:
- src: is2res_train_val_test_lmdbs/data/is2re/all/descriptors/val_id/data.mdb

model:
name: orbnet_native_ofms
name: ofmnet_native_ofms
emb_size_atom: 128
emb_size_edge: 64
cutoff: 6.
Expand Down
8 changes: 4 additions & 4 deletions src/models/orbnet_native.py → src/models/ofmnet_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def forward(

return h_enc, e_enc, e_aux

class OrbNet(nn.Module):
class OfmNet(nn.Module):
def __init__(
self,
num_radial: int = 8,
Expand Down Expand Up @@ -313,8 +313,8 @@ def reset_parameters(self):
for output_block in self.output_blocks:
output_block.reset_parameters()

@registry.register_model("orbnet_native")
class OrbNetWrap(OrbNet):
@registry.register_model("ofmnet_native")
class OfmNetWrap(OfmNet):
def __init__(
self,
num_atoms, # not used
Expand Down Expand Up @@ -347,7 +347,7 @@ def __init__(
self.periods = period_and_group["Period"].values
self.groups = period_and_group["Group"].values

super(OrbNetWrap, self).__init__(
super(OfmNetWrap, self).__init__(
num_heads=num_heads,
num_radial=num_radial,
emb_size_atom=emb_size_atom,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def forward(
return h_enc, e_enc, e_aux


class OrbNet(nn.Module):
class OfmNet(nn.Module):
def __init__(
self,
num_radial: int = 8,
Expand Down Expand Up @@ -321,7 +321,7 @@ def reset_parameters(self):


@registry.register_model("orbnet_native_desc")
class OrbNetWrap(OrbNet):
class OfmNetWrap(OfmNet):
def __init__(
self,
num_atoms, # not used
Expand Down Expand Up @@ -359,7 +359,7 @@ def __init__(
self.periods = period_and_group["Period"].values
self.groups = period_and_group["Group"].values

super(OrbNetWrap, self).__init__(
super(OfmNetWrap, self).__init__(
num_heads=num_heads,
num_radial=num_radial,
emb_size_atom=emb_size_atom,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def forward(
return h_enc, e_enc, e_aux


class OrbNet(nn.Module):
class OfmNet(nn.Module):
def __init__(
self,
num_radial: int = 8,
Expand Down Expand Up @@ -311,7 +311,7 @@ def reset_parameters(self):


@registry.register_model("orbnet_native_ofms")
class OrbNetWrap(OrbNet):
class OfmNetWrap(OfmNet):
def __init__(
self,
num_atoms, # not used
Expand Down Expand Up @@ -344,7 +344,7 @@ def __init__(
self.periods = features["Period"].values
self.groups = features["Group"].values

super(OrbNetWrap, self).__init__(
super(OfmNetWrap, self).__init__(
num_heads=num_heads,
num_radial=num_radial,
emb_size_atom=emb_size_atom,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def forward(
return h_enc, e_enc, e_aux


class OrbNet(nn.Module):
class OfmNet(nn.Module):
def __init__(
self,
num_radial: int = 8,
Expand Down Expand Up @@ -307,8 +307,8 @@ def reset_parameters(self):
output_block.reset_parameters()


@registry.register_model("orbnet_native_surf")
class OrbNetWrap(OrbNet):
@registry.register_model("ofmnet_native_surf")
class OfmNetWrap(OfmNet):
def __init__(
self,
num_atoms, # not used
Expand Down Expand Up @@ -345,7 +345,7 @@ def __init__(
self.adsorb_atomic_numbers = (1, 6, 7, 8)
self.cov_coeff = cov_coeff

super(OrbNetWrap, self).__init__(
super(OfmNetWrap, self).__init__(
num_heads=num_heads,
num_radial=num_radial,
emb_size_atom=emb_size_atom,
Expand Down
8 changes: 4 additions & 4 deletions src/models/orbnet_native_surf_ofms.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def forward(
return h_enc, e_enc, e_aux


class OrbNet(nn.Module):
class OfmNet(nn.Module):
def __init__(
self,
num_radial: int = 8,
Expand Down Expand Up @@ -313,8 +313,8 @@ def reset_parameters(self):
output_block.reset_parameters()


@registry.register_model("orbnet_native_surf_ofms")
class OrbNetWrap(OrbNet):
@registry.register_model("ofmnet_native_surf_ofms")
class OfmNetWrap(OfmNet):
def __init__(
self,
num_atoms, # not used
Expand Down Expand Up @@ -351,7 +351,7 @@ def __init__(
self.adsorb_atomic_numbers = (1, 6, 7, 8)
self.cov_coeff = cov_coeff

super(OrbNetWrap, self).__init__(
super(OfmNetWrap, self).__init__(
num_heads=num_heads,
num_radial=num_radial,
emb_size_atom=emb_size_atom,
Expand Down

0 comments on commit 8eda49b

Please sign in to comment.