Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/python/common/app_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import signal
import sys
import threading
import time
from abc import abstractmethod
from datetime import datetime
from multiprocessing import Event, Process, Queue
Expand Down Expand Up @@ -114,7 +115,7 @@ def elapsed_ms(start):

timestamp_start = datetime.now()
while self.is_alive() and elapsed_ms(timestamp_start) < AppProcess.__DEFAULT_TERMINATE_TIMEOUT_MS:
pass
time.sleep(0.05)

super().terminate()

Expand Down
23 changes: 15 additions & 8 deletions src/python/controller/extract/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def add_listener(self, listener: ExtractListener):
self.__listeners_lock.release()

def status(self) -> list[ExtractStatus]:
tasks = list(self.__task_queue.queue)
with self.__task_queue.mutex:
tasks = list(self.__task_queue.queue)
statuses = []
for task in tasks:
status = ExtractStatus(
Expand Down Expand Up @@ -130,10 +131,11 @@ def extract(self, req: ExtractRequest):
model_file = req.model_file
self.logger.debug("Received extract for {}".format(model_file.name))

for task in self.__task_queue.queue:
if task.root_name == model_file.name and task.pair_id == req.pair_id:
self.logger.info("Ignoring extract for {}, already exists".format(model_file.name))
return
with self.__task_queue.mutex:
for task in self.__task_queue.queue:
if task.root_name == model_file.name and task.pair_id == req.pair_id:
self.logger.info("Ignoring extract for {}, already exists".format(model_file.name))
return

# noinspection PyProtectedMember
task = ExtractDispatch._Task(model_file.name, model_file.is_dir, req.pair_id)
Expand Down Expand Up @@ -183,9 +185,11 @@ def __worker(self):
while not self.__worker_shutdown.is_set():
# Try to grab next task
# Do another check for shutdown
while len(self.__task_queue.queue) > 0 and not self.__worker_shutdown.is_set():
# peek the task
task = self.__task_queue.queue[0]
with self.__task_queue.mutex:
has_tasks = len(self.__task_queue.queue) > 0
while has_tasks and not self.__worker_shutdown.is_set():
with self.__task_queue.mutex:
task = self.__task_queue.queue[0]

# We have a task, extract archives one by one
completed = True
Expand Down Expand Up @@ -217,6 +221,9 @@ def __worker(self):
listener.extract_failed(task.root_name, task.root_is_dir, task.pair_id)
self.__listeners_lock.release()

with self.__task_queue.mutex:
has_tasks = len(self.__task_queue.queue) > 0

time.sleep(ExtractDispatch.__WORKER_SLEEP_INTERVAL_IN_SECS)

self.logger.debug("Stopped worker thread")
Expand Down
3 changes: 2 additions & 1 deletion src/python/controller/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def _send_post(self, url: str, payload: dict):
try:
data = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(url, data=data, headers={"Content-Type": "application/json"}, method="POST")
urllib.request.urlopen(req, timeout=5)
with urllib.request.urlopen(req, timeout=5):
pass
self._logger.debug("Webhook sent: %s %s", payload["event_type"], payload["filename"])
except Exception as e:
self._logger.warning("Webhook failed: %s", str(e))
14 changes: 10 additions & 4 deletions src/python/ssh/sshcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ def __init__(self, host: str, port: int, user: str | None = None, password: str
self.__shell_detected: bool = False
self.logger = logging.getLogger(self.__class__.__name__)

def _remote_address(self) -> str:
"""Return 'user@host' when user is set, or just 'host' when None."""
if self.__user is not None:
return "{}@{}".format(self.__user, self.__host)
return self.__host

def set_base_logger(self, base_logger: logging.Logger):
self.logger = base_logger.getChild(self.__class__.__name__)

Expand Down Expand Up @@ -126,7 +132,7 @@ def _run_shell_command(self, command: str) -> bytes:
"-p",
str(self.__port),
]
args = ["{}@{}".format(self.__user, self.__host), quoted]
args = [self._remote_address(), quoted]
return self.__run_command(command="ssh", flags=" ".join(flags), args=" ".join(args))

def _check_remote_shells_via_sftp(self) -> list[str]:
Expand Down Expand Up @@ -162,7 +168,7 @@ def _sftp_stat(self, remote_path: str):
args = [
"-b",
"-", # read commands from stdin
"{}@{}".format(self.__user, self.__host),
self._remote_address(),
]

command_args = ["sftp"]
Expand Down Expand Up @@ -411,7 +417,7 @@ def shell(self, command: str) -> bytes:
"-p",
str(self.__port), # port
]
args = ["{}@{}".format(self.__user, self.__host), command]
args = [self._remote_address(), command]
return self.__run_command(command="ssh", flags=" ".join(flags), args=" ".join(args))

def copy(self, local_path: str, remote_path: str):
Expand All @@ -431,5 +437,5 @@ def copy(self, local_path: str, remote_path: str):
"-P",
str(self.__port), # port
]
args = [local_path, "{}@{}:{}".format(self.__user, self.__host, remote_path)]
args = [local_path, "{}:{}".format(self._remote_address(), remote_path)]
self.__run_command(command="scp", flags=" ".join(flags), args=" ".join(args))
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def create_archive(*args):
zf.write(temp_file_path, os.path.basename(temp_file_path))
zf.close()
elif ext == "rar":
fnull = open(os.devnull, "w")
subprocess.Popen(["rar", "a", "-ep", path, temp_file_path], stdout=fnull).communicate()
subprocess.run(["rar", "a", "-ep", path, temp_file_path], stdout=subprocess.DEVNULL, check=True)
else:
raise ValueError("Unsupported archive format: {}".format(os.path.basename(path)))
return os.path.getsize(path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,24 @@ def setUpClass(cls):
zf.close()

# rar
fnull = open(os.devnull, "w")
TestExtract.ar_rar = os.path.join(archive_dir, "file.rar")
subprocess.Popen(["rar", "a", "-ep", TestExtract.ar_rar, temp_file], stdout=fnull)
subprocess.run(["rar", "a", "-ep", TestExtract.ar_rar, temp_file], stdout=subprocess.DEVNULL, check=True)

# rar split
subprocess.Popen(
["rar", "a", "-ep", "-m0", "-v50k", os.path.join(archive_dir, "file.split.rar"), temp_file], stdout=fnull
subprocess.run(
["rar", "a", "-ep", "-m0", "-v50k", os.path.join(archive_dir, "file.split.rar"), temp_file],
stdout=subprocess.DEVNULL,
check=True,
)
TestExtract.ar_rar_split_p1 = os.path.join(archive_dir, "file.split.part1.rar")
TestExtract.ar_rar_split_p2 = os.path.join(archive_dir, "file.split.part2.rar")

# tar.gz
TestExtract.ar_tar_gz = os.path.join(archive_dir, "file.tar.gz")
subprocess.Popen(
["tar", "czvf", TestExtract.ar_tar_gz, "-C", os.path.dirname(temp_file), os.path.basename(temp_file)]
subprocess.run(
["tar", "czvf", TestExtract.ar_tar_gz, "-C", os.path.dirname(temp_file), os.path.basename(temp_file)],
stdout=subprocess.DEVNULL,
check=True,
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def _extract_archive(**kwargs):

time.sleep(0.1)

while self.mock_extract_archive.call_count < 1 and self.listener.extract_completed.call_count < 1:
while self.mock_extract_archive.call_count < 1 or self.listener.extract_completed.call_count < 1:
pass
time.sleep(0.1)
self.listener.extract_completed.assert_called_once_with("a", False, None)
Expand Down
19 changes: 19 additions & 0 deletions src/python/tests/unittests/test_ssh/test_sshcp_remote_address.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import unittest

from ssh import Sshcp


class TestSshcpRemoteAddress(unittest.TestCase):
"""Unit tests for Sshcp._remote_address helper."""

def test_remote_address_with_user(self):
sshcp = Sshcp(host="example.com", port=22, user="alice")
self.assertEqual("alice@example.com", sshcp._remote_address())

def test_remote_address_without_user(self):
sshcp = Sshcp(host="example.com", port=22, user=None)
self.assertEqual("example.com", sshcp._remote_address())

def test_remote_address_default_user(self):
sshcp = Sshcp(host="example.com", port=22)
self.assertEqual("example.com", sshcp._remote_address())