Skip to content

Commit c23a9ed

Browse files
authored
Prevent the default value from being used when values were already passed (#1675)
1 parent 202b682 commit c23a9ed

File tree

3 files changed

+91
-2
lines changed

3 files changed

+91
-2
lines changed

cwltool/argparser.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,43 @@ class DirectoryAppendAction(FSAppendAction):
817817
objclass = "Directory"
818818

819819

820+
class AppendAction(argparse.Action):
821+
"""An argparse action that clears the default values if any value is provided.
822+
823+
Attributes:
824+
_called (bool): Initially set to ``False``, changed if any value is appended.
825+
"""
826+
827+
def __init__(
828+
self,
829+
option_strings: List[str],
830+
dest: str,
831+
nargs: Any = None,
832+
**kwargs: Any,
833+
) -> None:
834+
"""Intialize."""
835+
super().__init__(option_strings, dest, **kwargs)
836+
self._called = False
837+
838+
def __call__(
839+
self,
840+
parser: argparse.ArgumentParser,
841+
namespace: argparse.Namespace,
842+
values: Union[str, Sequence[Any], None],
843+
option_string: Optional[str] = None,
844+
) -> None:
845+
g = getattr(namespace, self.dest, None)
846+
if g is None:
847+
g = []
848+
if self.default is not None and not self._called:
849+
# If any value was specified, we then clear the list of options before appending.
850+
# We cannot always clear the ``default`` attribute since it collects the ``values`` appended.
851+
self.default.clear()
852+
self._called = True
853+
g.append(values)
854+
setattr(namespace, self.dest, g)
855+
856+
820857
def add_argument(
821858
toolparser: argparse.ArgumentParser,
822859
name: str,
@@ -864,7 +901,7 @@ def add_argument(
864901
elif inptype["items"] == "Directory":
865902
action = DirectoryAppendAction
866903
else:
867-
action = "append"
904+
action = AppendAction
868905
elif isinstance(inptype, MutableMapping) and inptype["type"] == "enum":
869906
atype = str
870907
elif isinstance(inptype, MutableMapping) and inptype["type"] == "record":

tests/default_values_list.cwl

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/usr/bin/env cwl-runner
2+
# From https://github.com/common-workflow-language/cwltool/issues/1632
3+
4+
cwlVersion: v1.2
5+
class: CommandLineTool
6+
7+
baseCommand: [cat]
8+
9+
stdout: "cat_file"
10+
11+
inputs:
12+
file_paths:
13+
type: string[]?
14+
inputBinding:
15+
position: 1
16+
default: ["/home/bart/cwl_test/test1"]
17+
18+
outputs:
19+
output:
20+
type: stdout

tests/test_toolargparse.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
from io import StringIO
33
from pathlib import Path
4-
from typing import Callable
4+
from typing import Any, Callable, Dict, List
55

66
import pytest
77

@@ -195,3 +195,35 @@ def test_argparser_without_doc() -> None:
195195
p = argparse.ArgumentParser()
196196
parser = generate_parser(p, tool, {}, [], False)
197197
assert parser.description is None
198+
199+
200+
@pytest.mark.parametrize(
201+
"job_order,expected_values",
202+
[
203+
# no arguments, so we expect the default value
204+
([], ["/home/bart/cwl_test/test1"]),
205+
# arguments, provided, one or many, meaning that the default value is not expected
206+
(["--file_paths", "/home/bart/cwl_test/test2"], ["/home/bart/cwl_test/test2"]),
207+
(
208+
[
209+
"--file_paths",
210+
"/home/bart/cwl_test/test2",
211+
"--file_paths",
212+
"/home/bart/cwl_test/test3",
213+
],
214+
["/home/bart/cwl_test/test2", "/home/bart/cwl_test/test3"],
215+
),
216+
],
217+
)
218+
def test_argparse_append_with_default(
219+
job_order: List[str], expected_values: List[str]
220+
) -> None:
221+
"""The appended arguments must not include the default. But if no appended argument, then the default is used."""
222+
loadingContext = LoadingContext()
223+
tool = load_tool(get_data("tests/default_values_list.cwl"), loadingContext)
224+
toolparser = generate_parser(
225+
argparse.ArgumentParser(prog="test"), tool, {}, [], False
226+
)
227+
cmd_line = vars(toolparser.parse_args(job_order))
228+
file_paths = list(cmd_line["file_paths"])
229+
assert expected_values == file_paths

0 commit comments

Comments
 (0)