7
7
import threading
8
8
from typing import Callable , Optional
9
9
10
- from .context import RuntimeContext
11
10
from .errors import WorkflowKillSwitch
12
11
from .loghandler import _logger
13
12
@@ -35,7 +34,7 @@ class TaskQueue:
35
34
in_flight : int = 0
36
35
"""The number of tasks in the queue."""
37
36
38
- def __init__ (self , lock : threading .Lock , thread_count : int , runtime_context : RuntimeContext ):
37
+ def __init__ (self , lock : threading .Lock , thread_count : int , kill_switch : threading . Event ):
39
38
"""Create a new task queue using the specified lock and number of threads."""
40
39
self .thread_count = thread_count
41
40
self .task_queue : queue .Queue [Optional [Callable [[], None ]]] = queue .Queue (
@@ -44,11 +43,7 @@ def __init__(self, lock: threading.Lock, thread_count: int, runtime_context: Run
44
43
self .task_queue_threads = []
45
44
self .lock = lock
46
45
self .error : Optional [BaseException ] = None
47
-
48
- if runtime_context .kill_switch is None :
49
- self .kill_switch = runtime_context .kill_switch = threading .Event ()
50
- else :
51
- self .kill_switch = runtime_context .kill_switch
46
+ self .kill_switch = kill_switch
52
47
53
48
for _r in range (0 , self .thread_count ):
54
49
t = threading .Thread (target = self ._task_queue_func )
0 commit comments