From 200cddcb1abad216970c53fef3fb9253b9d8bab8 Mon Sep 17 00:00:00 2001 From: pwwang Date: Sat, 10 Aug 2024 23:09:16 -0700 Subject: [PATCH] fix: fix type checking in `utils.load_pipeline()` --- pipen/utils.py | 26 ++++++++++---------------- tests/helpers.py | 11 +++++++++++ tests/test_utils.py | 7 +++++++ 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/pipen/utils.py b/pipen/utils.py index 3193da94..59a8f8ac 100644 --- a/pipen/utils.py +++ b/pipen/utils.py @@ -667,15 +667,15 @@ async def load_pipeline( try: if isinstance(obj, str): obj = _get_obj_from_spec(obj) - if isinstance(obj, Pipen) or ( - isinstance(obj, type) and issubclass(obj, (Pipen, Proc, ProcGroup)) - ): - pass - else: - raise TypeError( - "Expected a Pipen, Proc, ProcGroup class, or a Pipen object, " - f"got {type(obj)}" - ) + if isinstance(obj, Pipen) or ( + isinstance(obj, type) and issubclass(obj, (Pipen, Proc, ProcGroup)) + ): + pass + else: + raise TypeError( + "Expected a Pipen, Proc, ProcGroup class, or a Pipen object, " + f"got {type(obj)}" + ) pipeline = obj if isinstance(obj, type) and issubclass(obj, Proc): @@ -689,15 +689,9 @@ async def load_pipeline( # Avoid "pipeline" to be used as pipeline name by varname (pipeline, ) = (obj(**kwargs), ) # type: ignore - else: # obj is a Pipen instance + elif isinstance(obj, Pipen): pipeline._kwargs.update(kwargs) - if not isinstance(pipeline, Pipen): - raise TypeError( - "Expected a Pipen, Proc or ProcGroup class, " - f"got {type(pipeline)}" - ) - # Initialize the pipeline so that the arguments definied by # other plugins (i.e. pipen-args) to take in place. pipeline.workdir = Path(pipeline.config.workdir).joinpath( diff --git a/tests/helpers.py b/tests/helpers.py index 95d219f2..bfc0821b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -254,3 +254,14 @@ def create_dead_link(path): link.unlink() link.symlink_to(target) target.unlink() + + +# for load_pipeline tests +pipeline = Pipen( + name=f"simple_pipeline_{Pipen.PIPELINE_COUNT + 1}", + desc="No description", + loglevel="debug", + cache=True, + workdir=gettempdir() + "/.pipen", + outdir=gettempdir() + f"/pipen_simple_{Pipen.PIPELINE_COUNT}", +).set_starts(SimpleProc) diff --git a/tests/test_utils.py b/tests/test_utils.py index afbb5492..9b41b4d7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -196,6 +196,13 @@ class P1(Proc): assert len(pipeline.procs) == 1 +@pytest.mark.forked +@pytest.mark.asyncio +async def test_load_pipeline_pipen_object(tmp_path): + p = await load_pipeline(f"{HERE}/helpers.py:pipeline", a=1) + assert p._kwargs["a"] == 1 + + @pytest.mark.forked # To avoid: Another plugin named simpleplugin has already been registered. @pytest.mark.asyncio