Skip to content

Commit

Permalink
test: move test wf with defualt val under standard test positional args
Browse files Browse the repository at this point in the history
Signed-off-by: machichima <[email protected]>
  • Loading branch information
machichima committed Feb 22, 2025
1 parent 818555f commit ced38bd
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,20 +985,17 @@ def wf_mixed_positional_and_keyword_args() -> int:
assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret

def test_positional_args_workflow_with_default_value():
def test_positional_args_workflow():
arg1 = 5
arg2 = 6
default_arg1 = 1
default_arg2 = 2
ret = 17
ret_arg2_default = 9

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def sub_wf(x: int = default_arg1, y: int = default_arg2) -> int:
def sub_wf(x: int, y: int) -> int:
return t1(x=x, y=y)

@workflow
Expand All @@ -1009,17 +1006,11 @@ def wf_pure_positional_args() -> int:
def wf_mixed_positional_and_keyword_args() -> int:
return sub_wf(arg1, y=arg2)

@workflow
def wf_mixed_positional_and_default_value() -> int:
return sub_wf(arg1)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)
wf_mixed_positional_and_default_value_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_default_value)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
default_arg2_binding = Scalar(primitive=Primitive(integer=default_arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
Expand All @@ -1030,25 +1021,23 @@ def wf_mixed_positional_and_default_value() -> int:
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_mixed_positional_and_default_value_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_default_value_spec.template.nodes[0].inputs[1].binding.value == default_arg2_binding
assert wf_mixed_positional_and_default_value_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret
assert wf_mixed_positional_and_default_value() == ret_arg2_default

def test_positional_args_workflow():
def test_positional_args_workflow_with_default_value():
arg1 = 5
arg2 = 6
default_arg1 = 1
default_arg2 = 2
ret = 17
ret_arg2_default = 9

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def sub_wf(x: int, y: int) -> int:
def sub_wf(x: int = default_arg1, y: int = default_arg2) -> int:
return t1(x=x, y=y)

@workflow
Expand All @@ -1059,11 +1048,17 @@ def wf_pure_positional_args() -> int:
def wf_mixed_positional_and_keyword_args() -> int:
return sub_wf(arg1, y=arg2)

@workflow
def wf_mixed_positional_and_default_value() -> int:
return sub_wf(arg1)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)
wf_mixed_positional_and_default_value_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_default_value)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
default_arg2_binding = Scalar(primitive=Primitive(integer=default_arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
Expand All @@ -1074,8 +1069,13 @@ def wf_mixed_positional_and_keyword_args() -> int:
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_mixed_positional_and_default_value_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_default_value_spec.template.nodes[0].inputs[1].binding.value == default_arg2_binding
assert wf_mixed_positional_and_default_value_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret
assert wf_mixed_positional_and_default_value() == ret_arg2_default


def test_positional_args_workflow_extra_args_or_kwargs():
Expand Down

0 comments on commit ced38bd

Please sign in to comment.