Skip to content

Commit 62e1661

Browse files
fix: DistanceConfFilter is optimized and the bug in BoxSkewnessConfFilter is fixed. (#292)
The DistanceConfFilter is optimized and the bug in BoxSkewnessConfFilter is fixed. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved accuracy of hydrogen distance checks by updating the safe distance value. - Enhanced structure validation by refining lattice length and skewness checks for better symmetry handling. - **Refactor** - Streamlined distance and lattice checks to operate directly on the original structure, improving efficiency and reliability. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 749c7be commit 62e1661

File tree

1 file changed

+42
-10
lines changed

1 file changed

+42
-10
lines changed

dpgen2/exploration/selector/distance_conf_filter.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
)
2222

2323
safe_dist_dict = {
24-
"H": 1.2255,
24+
"H": 0.612,
2525
"He": 0.936,
2626
"Li": 1.8,
2727
"Be": 1.56,
@@ -169,16 +169,48 @@ def check(
169169
pbc=(not frame.nopbc),
170170
)
171171

172-
P = [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
173-
extended_structure = make_supercell(structure, P)
174-
175-
coords = extended_structure.positions
176-
symbols = extended_structure.get_chemical_symbols()
172+
coords = structure.positions
173+
symbols = structure.get_chemical_symbols()
174+
cell, _ = structure.get_cell().standard_form()
175+
cell = cell.array
176+
177+
a1 = cell[0]
178+
a2 = cell[1]
179+
a3 = cell[2]
180+
181+
all_combinations = {
182+
"a1": np.linalg.norm(a1),
183+
"a2": np.linalg.norm(a2),
184+
"a3": np.linalg.norm(a3),
185+
"a1+a2": np.linalg.norm(a1 + a2),
186+
"a1+a3": np.linalg.norm(a1 + a3),
187+
"a2+a3": np.linalg.norm(a2 + a3),
188+
"a1-a2": np.linalg.norm(a1 - a2),
189+
"a1-a3": np.linalg.norm(a1 - a3),
190+
"a2-a3": np.linalg.norm(a2 - a3),
191+
"a1+a2+a3": np.linalg.norm(a1 + a2 + a3),
192+
"a1+a2-a3": np.linalg.norm(a1 + a2 - a3),
193+
"a1-a2+a3": np.linalg.norm(a1 - a2 + a3),
194+
"a1-a2-a3": np.linalg.norm(a1 - a2 - a3),
195+
"-a1+a2+a3": np.linalg.norm(-a1 + a2 + a3),
196+
"-a1+a2-a3": np.linalg.norm(-a1 + a2 - a3),
197+
"-a1-a2+a3": np.linalg.norm(-a1 - a2 + a3),
198+
"-a1-a2-a3": np.linalg.norm(-a1 - a2 - a3),
199+
}
200+
201+
A = list(all_combinations.values())
202+
B = [safe_dist[type_i] * 2 for type_i in symbols]
203+
204+
for a in A:
205+
for b in B:
206+
if a < b:
207+
print(f"Lattice length {a:.3f} is less than safe distance {b:.3f} ")
208+
return False
177209

178210
num_atoms = len(coords)
179211
for i in range(num_atoms):
180212
for j in range(i + 1, num_atoms):
181-
dist = extended_structure.get_distance(i, j, mic=True)
213+
dist = structure.get_distance(i, j, mic=True)
182214
type_i = symbols[i]
183215
type_j = symbols[j]
184216
dr = safe_dist[type_i] + safe_dist[type_j]
@@ -269,9 +301,9 @@ def check(
269301
cell, _ = structure.get_cell().standard_form()
270302

271303
if (
272-
cell[1][0] > np.tan(self.theta / 180.0 * np.pi) * cell[1][1] # type: ignore
273-
or cell[2][0] > np.tan(self.theta / 180.0 * np.pi) * cell[2][2] # type: ignore
274-
or cell[2][1] > np.tan(self.theta / 180.0 * np.pi) * cell[2][2] # type: ignore
304+
np.abs(cell[1][0]) > np.tan(self.theta / 180.0 * np.pi) * cell[1][1] # type: ignore
305+
or np.abs(cell[2][0]) > np.tan(self.theta / 180.0 * np.pi) * cell[2][2] # type: ignore
306+
or np.abs(cell[2][1]) > np.tan(self.theta / 180.0 * np.pi) * cell[2][2] # type: ignore
275307
):
276308
logging.warning("Inclined box")
277309
return False

0 commit comments

Comments
 (0)