diff --git a/README.md b/README.md index b1efbb5..467e996 100644 --- a/README.md +++ b/README.md @@ -1101,6 +1101,32 @@ if code cannot establish connection, it will retry using deploy port (+1 for def If code cannot establish connection, it will start deployment of python using [Deployment setup tool](#deployment-setup-tool) and establish connection again. +## RShell connection + +`RShellConnection` is a connection type that leverages `httplib.client` library to establish and manage connections using RESTful API over HTTP protocol in EFI Shell environment. `rshell_client.py` must be present on the EFI Shell target system. + +`rshell_server.py` is a server script that needs to be executed on the host machine to facilitate communication between the host and the EFI Shell target system. Server works on queue with address ip -> command to execute, it provides a RESTful API that allows the host to send commands and receive responses from the EFI Shell: + + * `/execute_command` - Endpoint to execute commands on the EFI Shell target system: + Form fields: + * `timeout` - Timeout for command execution. + * `command` - Command to be executed. + * `ip` - IP address of the EFI Shell target system. + * `/post_result` - Endpoint to post results back to the host. + Headers fields: + * `CommandID` - Unique identifier for the command. + * `rc` - Return code of the executed command. + Body: + * Command output. + * `/exception` - Endpoint to handle exceptions that may occur during communication. + Headers fields: + * `CommandID` - Unique identifier for the command. + Body: + * Exception details. + * `/getCommandToExecute` - Endpoint to retrieve commands to be executed on the EFI Shell target system. Returns commandline with generated CommandID. + * `/health/` - Endpoint to check the health status of the connection. + +`rshell.py` is a Connection class that calls RESTful API endpoints provided by `rshell_server.py` to execute commands on the EFI Shell target system. If required, starts `rshell_server.py` on the host machine. ## OS supported: * LNX diff --git a/examples/rshell_example.py b/examples/rshell_example.py new file mode 100644 index 0000000..8b0bc8b --- /dev/null +++ b/examples/rshell_example.py @@ -0,0 +1,11 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: MIT +import logging +logging.basicConfig(level=logging.DEBUG) +from mfd_connect.rshell import RShellConnection + +# LINUX +conn = RShellConnection(ip="10.10.10.10") # start and connect to rshell server +# conn = RShellConnection(ip="10.10.10.10", server_ip="10.10.10.11") # connect to rshell server +conn.execute_command("ls") +conn.disconnect(True) diff --git a/mfd_connect/rshell.py b/mfd_connect/rshell.py new file mode 100644 index 0000000..369aa62 --- /dev/null +++ b/mfd_connect/rshell.py @@ -0,0 +1,233 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: MIT +"""RShell Connection Class.""" + +import logging +import sys +import time +import typing +from ipaddress import IPv4Address, IPv6Address +from subprocess import CalledProcessError + +import requests +from mfd_common_libs import add_logging_level, log_levels, TimeoutCounter +from mfd_typing.cpu_values import CPUArchitecture +from mfd_typing.os_values import OSBitness, OSName, OSType + +from mfd_connect.local import LocalConnection +from mfd_connect.pathlib.path import CustomPath, custom_path_factory +from mfd_connect.process.base import RemoteProcess + +from .base import Connection, ConnectionCompletedProcess + +if typing.TYPE_CHECKING: + from pydantic import ( + BaseModel, # from pytest_mfd_config.models.topology import ConnectionModel + ) + + +logger = logging.getLogger(__name__) +add_logging_level(level_name="MODULE_DEBUG", level_value=log_levels.MODULE_DEBUG) +add_logging_level(level_name="CMD", level_value=log_levels.CMD) +add_logging_level(level_name="OUT", level_value=log_levels.OUT) + + +class RShellConnection(Connection): + """RShell Connection Class.""" + + def __init__( + self, + ip: str | IPv4Address | IPv6Address, + server_ip: str | IPv4Address | IPv6Address | None = "127.0.0.1", + model: "BaseModel | None" = None, + cache_system_data: bool = True, + connection_timeout: int = 60, + ): + """ + Initialize RShellConnection. + + :param ip: The IP address of the RShell server. + :param server_ip: The IP address of the server to connect to (optional). + :param model: The Pydantic model to use for the connection (optional). + :param cache_system_data: Whether to cache system data (default: True). + """ + super().__init__(model=model, cache_system_data=cache_system_data) + self._ip = ip + self.server_ip = server_ip if server_ip else "127.0.0.1" + self.server_process = None + if server_ip == "127.0.0.1": + # start Rshell server + self.server_process = self._run_server() + time.sleep(5) + timeout = TimeoutCounter(connection_timeout) + while not timeout: + logger.log(level=log_levels.MODULE_DEBUG, msg="Checking RShell server health") + status_code = requests.get( + f"http://{self.server_ip}/health/{self._ip}", proxies={"no_proxy": "*"} + ).status_code + if status_code == 200: + logger.log(level=log_levels.MODULE_DEBUG, msg="RShell server is healthy") + break + time.sleep(5) + else: + raise TimeoutError("Connection of Client to RShell server timed out") + + def disconnect(self, stop_client: bool = False) -> None: + """ + Disconnect connection. + + Stop local RShell server if established. + + :param stop_client: Whether to stop the RShell client (default: False). + """ + if stop_client: + logger.log(level=log_levels.MODULE_DEBUG, msg="Stopping RShell client") + self.execute_command("end") + if self.server_process: + logger.log(level=log_levels.MODULE_DEBUG, msg="Stopping RShell server") + self.server_process.kill() + logger.log(level=log_levels.MODULE_DEBUG, msg="RShell server stopped") + logger.log(level=log_levels.MODULE_DEBUG, msg=self.server_process.stdout_text) + + def _run_server(self) -> RemoteProcess: + """Run RShell server locally.""" + conn = LocalConnection() + server_file = conn.path(__file__).parent / "rshell_server.py" + return conn.start_process(f"{conn.modules().sys.executable} {server_file}") + + def execute_command( + self, + command: str, + *, + input_data: str | None = None, + cwd: str | None = None, + timeout: int | None = None, + env: dict | None = None, + stderr_to_stdout: bool = False, + discard_stdout: bool = False, + discard_stderr: bool = False, + skip_logging: bool = False, + expected_return_codes: list[int] | None = None, + shell: bool = False, + custom_exception: type[CalledProcessError] | None = None, + ) -> ConnectionCompletedProcess: + """ + Execute a command on the remote server. + + :param command: The command to execute. + :param timeout: The timeout for the command execution (optional). + :return: The result of the command execution. + """ + if input_data is not None: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Input data is not supported for RShellConnection and will be ignored.", + ) + + if cwd is not None: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="CWD is not supported for RShellConnection and will be ignored.", + ) + + if env is not None: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Environment variables are not supported for RShellConnection and will be ignored.", + ) + + if stderr_to_stdout: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Redirecting stderr to stdout is not supported for RShellConnection and will be ignored.", + ) + + if discard_stdout: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Discarding stdout is not supported for RShellConnection and will be ignored.", + ) + + if discard_stderr: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Discarding stderr is not supported for RShellConnection and will be ignored.", + ) + + if skip_logging: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Skipping logging is not supported for RShellConnection and will be ignored.", + ) + + if expected_return_codes is not None: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Expected return codes are not supported for RShellConnection and will be ignored.", + ) + + if shell: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Shell execution is not supported for RShellConnection and will be ignored.", + ) + + if custom_exception: + logger.log( + level=log_levels.MODULE_DEBUG, + msg="Custom exceptions are not supported for RShellConnection and will be ignored.", + ) + timeout_string = f" with timeout {timeout} seconds" if timeout is not None else "" + logger.log(level=log_levels.CMD, msg=f"Executing >{self._ip}> '{command}',{timeout_string}") + + response = requests.post( + f"http://{self.server_ip}/execute_command", + data={"command": command, "timeout": timeout, "ip": self._ip}, + proxies={"no_proxy": "*"}, + ) + completed_process = ConnectionCompletedProcess( + args=command, + stdout=response.text, + return_code=int(response.headers.get("rc", -1)), + ) + logger.log( + level=log_levels.MODULE_DEBUG, + msg=f"Finished executing '{command}', rc={completed_process.return_code}", + ) + if skip_logging: + return completed_process + + stdout = completed_process.stdout + if stdout: + logger.log(level=log_levels.OUT, msg=f"stdout>>\n{stdout}") + + return completed_process + + def path(self, *args, **kwargs) -> CustomPath: + """Path represents a filesystem path.""" + if sys.version_info >= (3, 12): + kwargs["owner"] = self + return custom_path_factory(*args, **kwargs) + + return CustomPath(*args, owner=self, **kwargs) + + def get_os_name(self) -> OSName: # noqa: D102 + raise NotImplementedError + + def get_os_type(self) -> OSType: # noqa: D102 + raise NotImplementedError + + def get_os_bitness(self) -> OSBitness: # noqa: D102 + raise NotImplementedError + + def get_cpu_architecture(self) -> CPUArchitecture: # noqa: D102 + raise NotImplementedError + + def restart_platform(self) -> None: # noqa: D102 + raise NotImplementedError + + def shutdown_platform(self) -> None: # noqa: D102 + raise NotImplementedError + + def wait_for_host(self, timeout: int = 60) -> None: # noqa: D102 + raise NotImplementedError diff --git a/mfd_connect/rshell_client.py b/mfd_connect/rshell_client.py new file mode 100644 index 0000000..183d6ac --- /dev/null +++ b/mfd_connect/rshell_client.py @@ -0,0 +1,119 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: MIT +""" +RShell Client Script. + +Make sure that the Python UEFI interpreter is compiled with +Socket module support. +""" + +__version__ = "1.0.0" + +try: + import httplib as client +except ImportError: + from http import client +import sys +import os +import time + +# get http server ip +http_server = sys.argv[1] +if len(sys.argv) > 2: + source_address = sys.argv[2] +else: + source_address = None + +os_name = os.name + + +def _sleep(interval): # noqa: ANN001, ANN202 + """ + Simulate the sleep function for EFI shell as the sleep API from time module is not working on EFI shell. + + :param interval time period the system to be in idle + """ + start_ts = time.time() + while time.time() < start_ts + interval: + pass + + +time.sleep = _sleep + + +def _get_command(): # noqa: ANN202 + """Get the command from server to execute on client machine.""" + # construct the list of tests by interacting with server + conn.request("GET", "getCommandToExecute") + rsp = conn.getresponse() + status = rsp.status + _id = rsp.getheader("CommandID") + if status == 204: + return None + + print("Waiting for command from server: ") + data_received = rsp.read() + print(data_received) + test_list = data_received.split(b",") + + return test_list[0], _id # return only the first command + + +while True: + # Connect to server + source_address_parameter = (source_address, 80) if source_address else None + conn = client.HTTPConnection(http_server, source_address=source_address_parameter) + # get the command from server + _command = _get_command() + if not _command: + conn.close() + time.sleep(5) + continue + cmd_str, _id = _command + cmd_str = cmd_str.decode("utf-8") + cmd_name = cmd_str.split(" ")[0] + if cmd_name == "end": + print("No more commands available to run") + conn.close() + exit(0) + + print("Executing", cmd_str) + + out = cmd_name + ".txt" + cmd = cmd_str + " > " + out + + time.sleep(5) + rc = os.system(cmd) # execute command on machine + print("Executed the command") + time.sleep(5) + + print("Posting the results to server") + # send response to server + try: + if os_name == "edk2": + encoding = "utf-16" + else: + encoding = "utf-8" + + f = open(out, "r", encoding=encoding) + + conn.request( + "POST", + "post_result", + body=f.read(), + headers={"Content-Type": "text/plain", "Connection": "keep-alive", "CommandID": _id, "rc": rc}, + ) + f.close() + os.system("del " + out) + except Exception as exp: + conn.request( + "POST", + "exception", + body=cmd + str(exp), + headers={"Content-Type": "text/plain", "Connection": "keep-alive", "CommandID": _id}, + ) + + print("output posted to server") + conn.close() + print("closed the connection") + time.sleep(1) diff --git a/mfd_connect/rshell_server.py b/mfd_connect/rshell_server.py new file mode 100644 index 0000000..c1585b3 --- /dev/null +++ b/mfd_connect/rshell_server.py @@ -0,0 +1,162 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: MIT +""" +RShell Server Script. + +This script implements a RESTful server using Flask to manage command execution +on connected RShell clients. +""" + +import time +from collections import namedtuple +from queue import Queue +from uuid import uuid4 + +from flask import Flask, Response, request + +__version__ = "1.0.0" + +# Global command queue +output_object = namedtuple("OutputObject", ["output", "rc"]) +command_object = namedtuple("CommandObject", ["command_id", "str"]) + +output_queue: dict[str, output_object] = dict() +command_dict_queue: dict[str, Queue] = dict() +clients: list = [] + +app = Flask(__name__) + + +def get_output(command_id: str, timeout: float = 600) -> output_object: + """ + Retrieve the output for a given command ID. + + :param command_id: The ID of the command to retrieve output for. + :param timeout: The maximum time to wait for output (in seconds). + :return: The output for the given command ID. + :raises TimeoutError: If the command times out. + """ + print("Getting output for command ID:", command_id) + print(f"Waiting for output {timeout} seconds") + timeout = timeout + 5 # add time for client loop waiting + while timeout > 0: + result = output_queue.get(command_id, None) + if result is not None: + return result + time.sleep(1) + timeout -= 1 + raise TimeoutError("Command timed out") + + +def add_command_to_queue(command: str, ip_address: str) -> str: + """ + Add a command to the global command queue. + + :param command: The command to add to the queue. + :param ip_address: The IP address of the client. + :return: The ID of the added command. + """ + print("Adding command to queue:", command) + _id = str(uuid4().int) + if command_dict_queue.get(ip_address) is None: + command_dict_queue[ip_address] = Queue() + command_dict_queue[ip_address].put(command_object(command_id=_id, str=command)) + return _id + + +@app.route("/health/", methods=["GET"]) +def health_check(ip: str) -> Response: + """Health check endpoint.""" + if ip in clients: + return Response("OK", status=200) + else: + return Response("Client not connected", status=503) + + +@app.route("/getCommandToExecute", methods=["GET"]) +def get_command_to_execute() -> Response: + """ + Get the next command to execute for the connected client. + + :return: The next command to execute. + """ + ip_address = str(request.remote_addr) + if ip_address not in clients: + print(f"Client connected: {ip_address}") + clients.append(ip_address) + client_queue = command_dict_queue.get(ip_address, Queue()) + if not client_queue.empty(): + command_object = client_queue.get() + return Response( + command_object.str, + status=200, + mimetype="text/plain", + headers={"CommandID": command_object.command_id}, + ) + else: + return Response("No more elements left in the queue", status=204) + + +@app.route("/exception", methods=["POST"]) +def post_exception() -> Response: + """ + Receive exception details from the client. + + :param body: The exception details. + :param CommandID: The ID of the command that caused the exception. + :return: A response indicating the exception was received. + """ + read_data = request.data + command_id = str(request.headers.get("CommandID")) + print("CommandID: ", command_id) + print(str(read_data, encoding="utf-8")) + output_queue[command_id] = output_object(output=str(read_data, encoding="utf-8"), rc=-1) + return Response("Exception received", status=200) + + +@app.route("/execute_command", methods=["POST"]) +def execute_command() -> Response: + """ + Execute a command on the connected client. + + :param command: The command to execute. + :param timeout: The maximum time to wait for command execution (in seconds). + :param ip: The IP address of the client. + :return: The output of the executed command. + """ + timeout = int(request.form.get("timeout", 600)) + command = request.form.get("command") + ip_address = str(request.form.get("ip")) + if command: + _id = add_command_to_queue(command, ip_address) + if command == "end": + return Response("No more commands available to run", status=200) + process = get_output(_id, timeout) + return Response( + process.output.encode("utf-8"), + status=200, + headers={ + "Content-type": "text/plain", + "CommandID": _id, + "rc": process.rc, + }, + ) + else: + return Response("No command provided", status=400) + + +@app.route("/post_result", methods=["POST"]) +def post_result() -> Response: + """Receive command execution results from the client.""" + read_data = request.data + command_id = str(request.headers.get("CommandID")) + rc = int(request.headers.get("rc", -1)) + print("CommandID: ", command_id) + print(str(read_data, encoding="utf-8")) + output_queue[command_id] = output_object(output=str(read_data, encoding="utf-8"), rc=rc) + return Response("Results received", status=200) + + +if __name__ == "__main__": + print("Starting Flask REST server...") + app.run(host="0.0.0.0", port=80) diff --git a/mfd_connect/ssh.py b/mfd_connect/ssh.py index 8bee6e6..0d8f269 100644 --- a/mfd_connect/ssh.py +++ b/mfd_connect/ssh.py @@ -562,7 +562,7 @@ def execute_command( logger.log( level=log_levels.MFD_INFO, msg="[Warning] A pseudo-terminal was requested, " - "but please be aware that this is not recommended and may lead to unexpected behavior.", + "but please be aware that this is not recommended and may lead to unexpected behavior.", ) self._verify_command_correctness(command) diff --git a/requirements.txt b/requirements.txt index 4b928c4..9b02308 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ psutil~=5.9.5; platform_system != 'VMkernel' mfd-ftp>=1.8.0 pywinrm~=0.4.3 netaddr -telnetlib-313-and-up; python_version >= '3.13' \ No newline at end of file +telnetlib-313-and-up; python_version >= '3.13' +flask \ No newline at end of file