Skip to content

Commit d44363e

Browse files
Fixed Incar object to allow for ML_MODE vasp tag (materialsproject#3625)
* Fixed Incar object to allow for ML_MODE vasp tag which does not want to have capitalized values (train, run, ...). By default, pymatgen capitalizes the values of the INCAR tags (e.g. "ALGO = Fast", even if you set incar["ALGO"] = "fast"). This is not working for the ML_MODE tag. Also fixed (i.e. ignored specific mypy code) random mypy errors in files that were not touched ... * rename lowerstr_keys -> lower_str_keys --------- Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 988da0c commit d44363e

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

pymatgen/core/units.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
conversion factors. What matters is the relative values, not the absolute.
4242
The SI units must have factor 1.
4343
"""
44-
BASE_UNITS = {
44+
BASE_UNITS: dict[str, dict] = {
4545
"length": {
4646
"m": 1,
4747
"km": 1000,

pymatgen/io/vasp/inputs.py

+4
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,7 @@ def proc_val(key: str, val: Any):
873873
"LDAUTYPE",
874874
"IVDW",
875875
)
876+
lower_str_keys = ("ML_MODE",)
876877

877878
def smart_int_or_float(num_str):
878879
if num_str.find(".") != -1 or num_str.lower().find("e") != -1:
@@ -904,6 +905,9 @@ def smart_int_or_float(num_str):
904905
if key in int_keys:
905906
return int(re.match(r"^-?[0-9]+", val).group(0)) # type: ignore
906907

908+
if key in lower_str_keys:
909+
return val.strip().lower()
910+
907911
except ValueError:
908912
pass
909913

tests/io/vasp/test_inputs.py

+3
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,9 @@ def test_types(self):
833833

834834
def test_proc_types(self):
835835
assert Incar.proc_val("HELLO", "-0.85 0.85") == "-0.85 0.85"
836+
assert Incar.proc_val("ML_MODE", "train") == "train"
837+
assert Incar.proc_val("ML_MODE", "RUN") == "run"
838+
assert Incar.proc_val("ALGO", "fast") == "Fast"
836839

837840
def test_check_params(self):
838841
# Triggers warnings when running into invalid parameters

0 commit comments

Comments
 (0)