1111import logging
1212import argparse
1313import multiprocessing
14- from typing import TYPE_CHECKING , Any , List , Optional
14+ from typing import TYPE_CHECKING , Any , List , Type , TypeVar , Optional
1515from multiprocessing import connection
1616
17- from .remote import RemoteExecutor
1817from ...common .flag import flags
1918from ...common .constants import DEFAULT_THREADLESS , DEFAULT_NUM_WORKERS
2019
2120
2221if TYPE_CHECKING : # pragma: no cover
2322 from ..event import EventQueue
23+ from .threadless import Threadless
24+
25+ T = TypeVar ('T' , bound = 'Threadless[Any]' )
2426
2527logger = logging .getLogger (__name__ )
2628
@@ -70,6 +72,7 @@ class ThreadlessPool:
7072 def __init__ (
7173 self ,
7274 flags : argparse .Namespace ,
75+ executor_klass : Type ['T' ],
7376 event_queue : Optional ['EventQueue' ] = None ,
7477 ) -> None :
7578 self .flags = flags
@@ -79,7 +82,9 @@ def __init__(
7982 self .work_pids : List [int ] = []
8083 self .work_locks : List ['multiprocessing.synchronize.Lock' ] = []
8184 # List of threadless workers
82- self ._workers : List [RemoteExecutor ] = []
85+ self ._executor_klass = executor_klass
86+ # FIXME: Instead of Any type must be the executor klass
87+ self ._workers : List [Any ] = []
8388 self ._processes : List [multiprocessing .Process ] = []
8489
8590 def __enter__ (self ) -> 'ThreadlessPool' :
@@ -115,8 +120,8 @@ def _start_worker(self, index: int) -> None:
115120 self .work_locks .append (multiprocessing .Lock ())
116121 pipe = multiprocessing .Pipe ()
117122 self .work_queues .append (pipe [0 ])
118- w = RemoteExecutor (
119- iid = index ,
123+ w = self . _executor_klass (
124+ iid = str ( index ) ,
120125 work_queue = pipe [1 ],
121126 flags = self .flags ,
122127 event_queue = self .event_queue ,
0 commit comments