75 changes: 47 additions & 28 deletions utils/update_checkout/update_checkout/parallel_runner.py
@@ -1,42 +1,71 @@
-from multiprocessing.managers import ListProxy, ValueProxy
 import sys
-from multiprocessing import cpu_count, Manager
+from multiprocessing import cpu_count
 import time
-from typing import Callable, List, Any, Union
+from typing import Callable, List, Any, Tuple, Union
 from threading import Lock, Thread, Event
 from concurrent.futures import ThreadPoolExecutor
 import shutil
 
 from .runner_arguments import RunnerArguments, AdditionalSwiftSourcesArguments
 
 
+class TaskTracker:
+    _running_tasks: List[str]
+    _done_task_counter: int
+    _lock: Lock
+
+    def __init__(self):
+        self._running_tasks = []
+        self._done_task_counter = 0
+        self._lock = Lock()
+
+    def mark_task_as_running(self, task_name: str):
+        self._lock.acquire()
+        self._running_tasks.append(task_name)
+        self._lock.release()
+
+    def mark_task_as_done(self, task_name: str):
+        self._lock.acquire()
+        if task_name in self._running_tasks:
+            self._running_tasks.remove(task_name)
+        self._done_task_counter += 1
+        self._lock.release()
+
+    def status(self) -> Tuple[str, int]:
+        self._lock.acquire()
+        running_tasks_str = ", ".join(self.running_tasks)
+        done_tasks = self.done_task_counter
+        self._lock.release()
+        return running_tasks_str, done_tasks
+
+    @property
+    def running_tasks(self) -> List[str]:
+        return self._running_tasks
+
+    @property
+    def done_task_counter(self) -> int:
+        return self._done_task_counter
+
+
 class MonitoredFunction:
     def __init__(
         self,
         fn: Callable,
-        running_tasks: ListProxy,
-        updated_repos: ValueProxy,
-        lock: Lock,
+        task_tracker: TaskTracker,
     ):
         self.fn = fn
-        self.running_tasks = running_tasks
-        self.updated_repos = updated_repos
-        self._lock = lock
+        self._task_tracker = task_tracker
 
     def __call__(self, *args: Union[RunnerArguments, AdditionalSwiftSourcesArguments]):
         task_name = args[0].repo_name
-        self.running_tasks.append(task_name)
+        self._task_tracker.mark_task_as_running(task_name)
         result = None
         try:
             result = self.fn(*args)
         except Exception as e:
             print(e)
         finally:
-            self._lock.acquire()
-            if task_name in self.running_tasks:
-                self.running_tasks.remove(task_name)
-            self.updated_repos.set(self.updated_repos.get() + 1)
-            self._lock.release()
+            self._task_tracker.mark_task_as_done(task_name)
         return result
 
 
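Note: the TaskTracker added above keeps all bookkeeping in-process behind a single threading.Lock, so worker threads and the monitor thread share it directly instead of going through multiprocessing Manager proxies. A minimal sketch of how the pieces fit together (not part of the change itself): fake_update is a hypothetical stand-in for the real per-repo work function, and SimpleNamespace stands in for RunnerArguments, of which only the repo_name field is assumed here.

import time
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace

# Hypothetical stand-in for the real per-repo update step.
def fake_update(args):
    time.sleep(0.1)  # pretend to do some git work
    return args.repo_name

tracker = TaskTracker()
monitored = MonitoredFunction(fake_update, tracker)

# SimpleNamespace stands in for RunnerArguments; only repo_name is assumed.
repos = [SimpleNamespace(repo_name=n) for n in ("swift", "llvm-project", "swift-syntax")]

with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(monitored, r) for r in repos]
    while not all(f.done() for f in futures):
        running, done = tracker.status()   # thread-safe snapshot: (joined names, count)
        print(f"[{done}/{len(repos)}] ({running})", end="\r")
        time.sleep(0.05)

print()
print([f.result() for f in futures])  # e.g. ['swift', 'llvm-project', 'swift-syntax']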
@@ -61,13 +90,8 @@ def __init__(
         self._stop_event = Event()
         self._verbose = pool_args[0].verbose
         if not self._verbose:
-            manager = Manager()
-            self._lock = manager.Lock()
-            self._running_tasks = manager.list()
-            self._updated_repos = manager.Value("i", 0)
-            self._monitored_fn = MonitoredFunction(
-                self._fn, self._running_tasks, self._updated_repos, self._lock
-            )
+            self._task_tracker = TaskTracker()
+            self._monitored_fn = MonitoredFunction(self._fn, self._task_tracker)
 
     def run(self) -> List[Any]:
         print(f"Running ``{self._fn.__name__}`` with up to {self._n_threads} processes.")
@@ -86,12 +110,7 @@ def run(self) -> List[Any]:
     def _monitor(self):
         last_output = ""
         while not self._stop_event.is_set():
-            self._lock.acquire()
-            current = list(self._running_tasks)
-            current_line = ", ".join(current)
-            updated_repos = self._updated_repos.get()
-            self._lock.release()
-
+            current_line, updated_repos = self._task_tracker.status()
             if current_line != last_output:
                 truncated = f"{self._output_prefix} [{updated_repos}/{self._nb_repos}] ({current_line})"
                 if len(truncated) > self._terminal_width:
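The rest of the _monitor body (how the over-long status line is actually shortened and written) is collapsed in this view. Purely as an illustration of the kind of truncation involved, and not necessarily what the file does, one way to clamp the status line to the terminal width is:

import shutil

def render_status_line(prefix: str, done: int, total: int, running: str) -> str:
    # Illustrative sketch only; the real truncation code is not shown in this diff.
    width = shutil.get_terminal_size().columns
    line = f"{prefix} [{done}/{total}] ({running})"
    if len(line) > width:
        line = line[: max(width - 3, 0)] + "..."
    # Pad so a shorter line fully overwrites the previous one when written with "\r".
    return line.ljust(width)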
37 changes: 0 additions & 37 deletions utils/update_checkout/update_checkout/update_checkout.py
@@ -13,7 +13,6 @@
 import os
 import platform
 import re
-import tempfile
 import sys
 import traceback
 from multiprocessing import freeze_support
@@ -28,28 +27,6 @@
 SCRIPT_FILE = os.path.abspath(__file__)
 SCRIPT_DIR = os.path.dirname(SCRIPT_FILE)
 
-def is_unix_sock_path_too_long() -> bool:
-    """Check if the unix socket_path exceeds the 104 limit (108 on Linux).
-
-    The multiprocessing module creates a socket in the TEMPDIR folder. The
-    socket path should not exceed:
-    - 104 bytes on macOS
-    - 108 bytes on Linux (https://www.man7.org/linux/man-pages/man7/unix.7.html)
-
-    Returns:
-        bool: Whether the socket path exceeds the limit. Always False on Windows.
-    """
-
-    if os.name != "posix":
-        return False
-
-    MAX_UNIX_SOCKET_PATH = 104
-    # `tempfile.mktemp` is deprecated yet that's what the multiprocessing
-    # module uses internally: https://github.com/python/cpython/blob/c4e7d245d61ac4547ecf3362b28f64fc00aa88c0/Lib/multiprocessing/connection.py#L72
-    # Since we are not using the resulting file, it is safe to use this
-    # method.
-    sock_path = tempfile.mktemp(prefix="sock-", dir=tempfile.gettempdir())
-    return len(sock_path.encode("utf-8")) > MAX_UNIX_SOCKET_PATH
 
 def confirm_tag_in_repo(repo_path: str, tag: str, repo_name: str) -> Optional[str]:
     """Confirm that a given tag exists in a git repository. This function
@@ -823,20 +800,6 @@ def main():
               "specify --scheme=foo")
         sys.exit(1)
 
-    if is_unix_sock_path_too_long():
-        if not args.dump_hashes and not args.dump_hashes_config:
-            # Do not print anything other than the json dump.
-            print(
-                f"TEMPDIR={tempfile.gettempdir()} is too long and multiprocessing "
-                "sockets will exceed the size limit. Falling back to verbose mode."
-            )
-        args.verbose = True
-    if sys.version_info.minor < 10:
-        if not args.dump_hashes and not args.dump_hashes_config:
-            # Do not print anything other than the json dump.
-            print("Falling back to verbose mode due to a Python 3.9 limitation.")
-        args.verbose = True
-
     clone = args.clone
     clone_with_ssh = args.clone_with_ssh
     skip_history = args.skip_history
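The guard removed above existed because the multiprocessing Manager binds an AF_UNIX socket under TEMPDIR, and such socket paths are limited to roughly 104 bytes on macOS and 108 on Linux; with the Manager gone from parallel_runner.py, no such socket is created and the check (along with the Python 3.9 verbose fallback) becomes unnecessary. For reference, a small sketch of the underlying constraint, assuming a POSIX system and the default temporary directory; the long prefix is arbitrary and only there to push the path past the limit:

import os
import socket
import tempfile

# Build a deliberately long path under the temp dir (well past 104/108 bytes).
long_dir = tempfile.mkdtemp(prefix="x" * 100)
sock_path = os.path.join(long_dir, "sock")

s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
    s.bind(sock_path)  # expected to fail with "AF_UNIX path too long"
except OSError as exc:
    print(exc)
finally:
    s.close()
    if os.path.exists(sock_path):
        os.unlink(sock_path)
    os.rmdir(long_dir)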