Skip to content

Commit f3d3278

Browse files
committed
PyscfCalculation: Validate parameters for unknown arguments
The validator for the `parameters` input is updated to check for any unknown arguments. This now raises an exception instead of silently ignoring them and starting the calculation.
1 parent 671245e commit f3d3278

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

src/aiida_pyscf/calculations/base.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""``CalcJob`` plugin for PySCF."""
33
from __future__ import annotations
44

5+
import copy
56
import io
67
import numbers
78
import pathlib
@@ -128,14 +129,15 @@ def define(cls, spec: CalcJobProcessSpec): # type: ignore[override]
128129
)
129130

130131
@classmethod
131-
def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: disable=too-many-return-statements,too-many-branches
132+
def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: disable=too-many-return-statements,too-many-branches,too-many-locals
132133
"""Validate the parameters input."""
133134
if not value:
134135
return None
135136

136-
parameters = value.get_dict()
137+
parameters = copy.deepcopy(value.get_dict())
137138

138-
mean_field_method = parameters.get('mean_field', {}).get('method')
139+
mean_field = parameters.pop('mean_field', {})
140+
mean_field_method = mean_field.pop('method', None)
139141
valid_methods = ['RKS', 'RHF', 'DKS', 'DHF', 'GKS', 'GHF', 'HF', 'KS', 'ROHF', 'ROKS', 'UKS', 'UHF']
140142
options = ' '.join(valid_methods)
141143

@@ -145,24 +147,24 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di
145147
if mean_field_method not in valid_methods:
146148
return f'Specified mean field method {mean_field_method} is not supported, choose from: {options}'
147149

148-
if 'chkfile' in parameters.get('mean_field', {}):
150+
if 'chkfile' in mean_field:
149151
return (
150152
'The `chkfile` cannot be specified in the `mean_field` parameters. It is set automatically by the '
151153
'plugin if the `checkpoint` input is provided.'
152154
)
153155

154-
if 'optimizer' in parameters:
156+
if (optimizer := parameters.pop('optimizer', None)) is not None:
155157
valid_solvers = ('geometric', 'berny')
156-
solver = parameters['optimizer'].get('solver')
158+
solver = optimizer.get('solver')
157159

158160
if solver is None:
159161
return f'No solver specified in `optimizer` parameters. Choose from: {valid_solvers}'
160162

161163
if solver.lower() not in valid_solvers:
162164
return f'Invalid solver `{solver}` specified in `optimizer` parameters. Choose from: {valid_solvers}'
163165

164-
if 'cubegen' in parameters:
165-
orbitals = parameters['cubegen'].get('orbitals')
166+
if (cubegen := parameters.pop('cubegen', None)) is not None:
167+
orbitals = cubegen.get('orbitals')
166168
indices = orbitals.get('indices') if orbitals is not None else None
167169

168170
if orbitals is not None and indices is None:
@@ -174,9 +176,9 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di
174176
if indices is not None and (not isinstance(indices, list) or any(not isinstance(e, int) for e in indices)):
175177
return f'The `cubegen.orbitals.indices` parameter should be a list of integers, but got: {indices}'
176178

177-
if 'fcidump' in parameters:
178-
active_spaces = parameters['fcidump'].get('active_spaces')
179-
occupations = parameters['fcidump'].get('occupations')
179+
if (fcidump := parameters.pop('fcidump', None)) is not None:
180+
active_spaces = fcidump.get('active_spaces')
181+
occupations = fcidump.get('occupations')
180182
arrays = []
181183

182184
for key, data in (('active_spaces', active_spaces), ('occupations', occupations)):
@@ -193,6 +195,13 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di
193195
if arrays[0].shape != arrays[1].shape:
194196
return 'The `fcipdump.active_spaces` and `fcipdump.occupations` arrays have different shapes.'
195197

198+
# Remove other known arguments
199+
for key in ('hessian', 'results', 'structure'):
200+
parameters.pop(key, None)
201+
202+
if unknown_keys := list(parameters.keys()):
203+
return f'The following arguments are not supported: {", ".join(unknown_keys)}'
204+
196205
def get_template_environment(self) -> Environment:
197206
"""Return the template environment that should be used for rendering.
198207

tests/calculations/test_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,15 @@ def test_invalid_parameters_mean_field_chkfile(generate_calc_job, generate_input
186186
generate_calc_job(PyscfCalculation, inputs=inputs)
187187

188188

189+
def test_invalid_parameters_unknown_arguments(generate_calc_job, generate_inputs_pyscf):
190+
"""Test validation of ``parameters`` raises if unknown arguments are included."""
191+
parameters = {'unknown_key': 'value'}
192+
inputs = generate_inputs_pyscf(parameters=parameters)
193+
194+
with pytest.raises(ValueError, match=r'The following arguments are not supported: unknown_key'):
195+
generate_calc_job(PyscfCalculation, inputs=inputs)
196+
197+
189198
@pytest.mark.parametrize(
190199
'parameters, expected', (
191200
({}, r'No solver specified in `optimizer` parameters'),

0 commit comments

Comments
 (0)