diff --git a/awscli/customizations/ecs/expressgateway/display_strategy.py b/awscli/customizations/ecs/expressgateway/display_strategy.py new file mode 100644 index 000000000000..1065232054ee --- /dev/null +++ b/awscli/customizations/ecs/expressgateway/display_strategy.py @@ -0,0 +1,165 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""Display strategy implementations for ECS Express Gateway Service monitoring.""" + +import asyncio +import time + +from botocore.exceptions import ClientError + +from awscli.customizations.utils import uni_print + + +class DisplayStrategy: + """Base class for display strategies. + + Each strategy controls its own execution model, timing, and output format. + """ + + def execute_monitoring(self, collector, start_time, timeout_minutes): + """Execute the monitoring loop. + + Args: + collector: ServiceViewCollector instance for data fetching + start_time: Start timestamp for timeout calculation + timeout_minutes: Maximum monitoring duration in minutes + """ + raise NotImplementedError + + +class InteractiveDisplayStrategy(DisplayStrategy): + """Interactive display strategy with async spinner and keyboard navigation. + + Uses dual async tasks: + - Data task: Polls ECS APIs every 5 seconds + - Spinner task: Updates display every 100ms with rotating spinner + """ + + def __init__(self, display, use_color): + """Initialize the interactive display strategy. + + Args: + display: Display instance from prompt_toolkit_display module + providing the interactive terminal interface + use_color: Whether to use colored output + """ + self.display = display + self.use_color = use_color + + def execute_monitoring(self, collector, start_time, timeout_minutes): + """Execute async monitoring with spinner and keyboard controls.""" + final_output, timed_out = asyncio.run( + self._execute_async(collector, start_time, timeout_minutes) + ) + if timed_out: + uni_print(final_output + "\nMonitoring timed out!\n") + else: + uni_print(final_output + "\nMonitoring Complete!\n") + + async def _execute_async(self, collector, start_time, timeout_minutes): + """Async execution with dual tasks for data and spinner.""" + spinner_chars = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" + spinner_index = 0 + current_output = "Waiting for initial data" + timed_out = False + + async def update_data(): + nonlocal current_output, timed_out + while True: + current_time = time.time() + if current_time - start_time > timeout_minutes * 60: + timed_out = True + self.display.app.exit() + break + + try: + loop = asyncio.get_event_loop() + new_output = await loop.run_in_executor( + None, collector.get_current_view, "{SPINNER}" + ) + current_output = new_output + except ClientError as e: + if ( + e.response.get('Error', {}).get('Code') + == 'InvalidParameterException' + ): + error_message = e.response.get('Error', {}).get( + 'Message', '' + ) + if ( + "Cannot call DescribeServiceRevisions for a service that is INACTIVE" + in error_message + ): + current_output = "Service is inactive" + else: + raise + else: + raise + + await asyncio.sleep(5.0) + + async def update_spinner(): + nonlocal spinner_index + while True: + spinner_char = spinner_chars[spinner_index] + display_output = current_output.replace( + "{SPINNER}", spinner_char + ) + status_text = f"Getting updates... {spinner_char} | up/down to scroll, q to quit" + self.display.display(display_output, status_text) + spinner_index = (spinner_index + 1) % len(spinner_chars) + await asyncio.sleep(0.1) + + data_task = asyncio.create_task(update_data()) + spinner_task = asyncio.create_task(update_spinner()) + display_task = None + + try: + display_task = asyncio.create_task(self.display.run()) + + done, pending = await asyncio.wait( + [display_task, data_task], return_when=asyncio.FIRST_COMPLETED + ) + + if data_task in done: + # Retrieve and re-raise any exception from the task. + # asyncio.wait() doesn't retrieve exceptions itself. + exc = data_task.exception() + if exc: + raise exc + + # Cancel pending tasks + for task in pending: + task.cancel() + # Await cancelled task to ensure proper cleanup and prevent + # warnings about unawaited tasks + try: + await task + except asyncio.CancelledError: + pass + + finally: + # Ensure display app is properly shut down + self.display.app.exit() + spinner_task.cancel() + if display_task is not None and not display_task.done(): + display_task.cancel() + # Await cancelled task to ensure proper cleanup and prevent + # warnings about unawaited tasks + try: + await display_task + except asyncio.CancelledError: + pass + + return current_output.replace("{SPINNER}", ""), timed_out diff --git a/awscli/customizations/ecs/monitorexpressgatewayservice.py b/awscli/customizations/ecs/monitorexpressgatewayservice.py index e946c010dce3..a5ed8769cf1c 100644 --- a/awscli/customizations/ecs/monitorexpressgatewayservice.py +++ b/awscli/customizations/ecs/monitorexpressgatewayservice.py @@ -39,15 +39,16 @@ aws ecs monitor-express-gateway-service --service-arn [--resource-view RESOURCE|DEPLOYMENT] """ -import asyncio import sys -import threading import time from botocore.exceptions import ClientError from awscli.customizations.commands import BasicCommand from awscli.customizations.ecs.exceptions import MonitoringError +from awscli.customizations.ecs.expressgateway.display_strategy import ( + InteractiveDisplayStrategy, +) from awscli.customizations.ecs.prompt_toolkit_display import Display from awscli.customizations.ecs.serviceviewcollector import ServiceViewCollector from awscli.customizations.utils import uni_print @@ -185,7 +186,7 @@ def __init__( service_arn, mode, timeout_minutes=30, - display=None, + display_strategy=None, use_color=True, collector=None, ): @@ -195,7 +196,9 @@ def __init__( self.timeout_minutes = timeout_minutes self.start_time = time.time() self.use_color = use_color - self.display = display or Display() + self.display_strategy = display_strategy or InteractiveDisplayStrategy( + display=Display(), use_color=use_color + ) self.collector = collector or ServiceViewCollector( client, service_arn, mode, use_color ) @@ -207,72 +210,6 @@ def is_monitoring_available(): def exec(self): """Start monitoring the express gateway service with progress display.""" - - def monitor_service(spinner_char): - return self.collector.get_current_view(spinner_char) - - asyncio.run(self._execute_with_progress_async(monitor_service, 100)) - - async def _execute_with_progress_async( - self, execution, progress_refresh_millis, execution_refresh_millis=5000 - ): - """Execute monitoring loop with animated progress display.""" - spinner_chars = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" - spinner_index = 0 - - current_output = "Waiting for initial data" - - async def update_data(): - nonlocal current_output - while True: - current_time = time.time() - if current_time - self.start_time > self.timeout_minutes * 60: - break - try: - loop = asyncio.get_event_loop() - new_output = await loop.run_in_executor( - None, execution, "{SPINNER}" - ) - current_output = new_output - except ClientError as e: - if ( - e.response.get('Error', {}).get('Code') - == 'InvalidParameterException' - ): - error_message = e.response.get('Error', {}).get( - 'Message', '' - ) - if ( - "Cannot call DescribeServiceRevisions for a service that is INACTIVE" - in error_message - ): - current_output = "Service is inactive" - else: - raise - else: - raise - await asyncio.sleep(execution_refresh_millis / 1000.0) - - async def update_spinner(): - nonlocal spinner_index - while True: - spinner_char = spinner_chars[spinner_index] - display_output = current_output.replace( - "{SPINNER}", spinner_char - ) - status_text = f"Getting updates... {spinner_char} | up/down to scroll, q to quit" - self.display.display(display_output, status_text) - spinner_index = (spinner_index + 1) % len(spinner_chars) - await asyncio.sleep(progress_refresh_millis / 1000.0) - - # Start both tasks - data_task = asyncio.create_task(update_data()) - spinner_task = asyncio.create_task(update_spinner()) - - try: - await self.display.run() - finally: - data_task.cancel() - spinner_task.cancel() - final_output = current_output.replace("{SPINNER}", "") - uni_print(final_output + "\nMonitoring Complete!\n") + self.display_strategy.execute_monitoring( + self.collector, self.timeout_minutes, self.start_time + ) diff --git a/tests/unit/customizations/ecs/expressgateway/test_display_strategy.py b/tests/unit/customizations/ecs/expressgateway/test_display_strategy.py new file mode 100644 index 000000000000..119b3d252a43 --- /dev/null +++ b/tests/unit/customizations/ecs/expressgateway/test_display_strategy.py @@ -0,0 +1,232 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ + +import asyncio +import time +from unittest.mock import Mock, patch + +import pytest +from botocore.exceptions import ClientError +from prompt_toolkit.application import create_app_session +from prompt_toolkit.output import DummyOutput + +from awscli.customizations.ecs.expressgateway.display_strategy import ( + DisplayStrategy, + InteractiveDisplayStrategy, +) + + +class TestDisplayStrategy: + """Test base DisplayStrategy class.""" + + def test_base_strategy_not_implemented(self): + """Test base class raises NotImplementedError.""" + strategy = DisplayStrategy() + with pytest.raises(NotImplementedError): + strategy.execute_monitoring(None, None, None) + + +@pytest.fixture +def app_session(): + """Fixture that creates and manages an app session for prompt_toolkit.""" + with create_app_session(output=DummyOutput()) as session: + yield session + + +@pytest.fixture +def mock_display(): + """Fixture that creates a mock display for testing.""" + + async def mock_run_async(): + await asyncio.sleep(0.01) + + display = Mock() + display.display = Mock() + display.run = Mock(return_value=mock_run_async()) + return display + + +class TestInteractiveDisplayStrategy: + """Test InteractiveDisplayStrategy.""" + + @patch('time.sleep') + def test_execute_with_mock_display( + self, mock_sleep, app_session, mock_display + ): + """Test strategy executes with mocked display.""" + mock_collector = Mock() + mock_collector.get_current_view = Mock( + return_value="Test output {SPINNER}" + ) + + strategy = InteractiveDisplayStrategy( + display=mock_display, use_color=True + ) + + mock_sleep.side_effect = KeyboardInterrupt() + + start_time = time.time() + strategy.execute_monitoring( + mock_collector, start_time, timeout_minutes=1 + ) + + # Verify display was called + assert mock_display.display.called + assert mock_display.run.called + + def test_strategy_uses_provided_color_setting(self): + """Test strategy respects use_color parameter.""" + mock_display = Mock() + + strategy_with_color = InteractiveDisplayStrategy( + display=mock_display, use_color=True + ) + assert strategy_with_color.use_color is True + + strategy_no_color = InteractiveDisplayStrategy( + display=mock_display, use_color=False + ) + assert strategy_no_color.use_color is False + + @patch('time.sleep') + def test_completion_message_on_normal_exit( + self, mock_sleep, app_session, mock_display, capsys + ): + """Test displays completion message when monitoring completes normally.""" + mock_collector = Mock() + mock_collector.get_current_view = Mock(return_value="Resources ready") + + strategy = InteractiveDisplayStrategy( + display=mock_display, use_color=True + ) + + mock_sleep.side_effect = KeyboardInterrupt() + + start_time = time.time() + strategy.execute_monitoring( + mock_collector, start_time, timeout_minutes=1 + ) + + captured = capsys.readouterr() + assert "Monitoring Complete!" in captured.out + assert "Monitoring timed out!" not in captured.out + + @patch('time.sleep') + def test_collector_output_is_displayed( + self, mock_sleep, app_session, mock_display, capsys + ): + """Test that collector output appears in final output.""" + mock_collector = Mock() + unique_output = "LoadBalancer lb-12345 ACTIVE" + mock_collector.get_current_view = Mock(return_value=unique_output) + + strategy = InteractiveDisplayStrategy( + display=mock_display, use_color=True + ) + + mock_sleep.side_effect = KeyboardInterrupt() + + start_time = time.time() + strategy.execute_monitoring( + mock_collector, start_time, timeout_minutes=1 + ) + + captured = capsys.readouterr() + assert unique_output in captured.out + + @patch('time.sleep') + def test_execute_handles_service_inactive( + self, mock_sleep, app_session, mock_display, capsys + ): + """Test strategy handles service inactive error.""" + mock_collector = Mock() + error = ClientError( + error_response={ + 'Error': { + 'Code': 'InvalidParameterException', + 'Message': 'Cannot call DescribeServiceRevisions for a service that is INACTIVE', + } + }, + operation_name='DescribeServiceRevisions', + ) + mock_collector.get_current_view = Mock(side_effect=error) + + strategy = InteractiveDisplayStrategy( + display=mock_display, use_color=True + ) + + mock_sleep.side_effect = KeyboardInterrupt() + + start_time = time.time() + strategy.execute_monitoring( + mock_collector, start_time, timeout_minutes=1 + ) + + # Strategy should handle the error and set output to "Service is inactive" + captured = capsys.readouterr() + assert "Service is inactive" in captured.out + + @patch('time.sleep') + def test_execute_other_client_errors_propagate( + self, mock_sleep, app_session, mock_display + ): + """Test strategy propagates non-service-inactive ClientErrors.""" + mock_collector = Mock() + error = ClientError( + error_response={ + 'Error': { + 'Code': 'AccessDeniedException', + 'Message': 'Access denied', + } + }, + operation_name='DescribeServiceRevisions', + ) + mock_collector.get_current_view = Mock(side_effect=error) + + strategy = InteractiveDisplayStrategy( + display=mock_display, use_color=True + ) + + mock_sleep.side_effect = KeyboardInterrupt() + + start_time = time.time() + + # Other client errors should propagate + with pytest.raises(ClientError) as exc_info: + strategy.execute_monitoring( + mock_collector, start_time, timeout_minutes=1 + ) + + assert ( + exc_info.value.response['Error']['Code'] == 'AccessDeniedException' + ) + + @patch('time.sleep') + def test_display_cleanup_on_exception( + self, mock_sleep, app_session, mock_display + ): + """Test display app is properly shut down when exception occurs.""" + mock_collector = Mock() + error = ClientError( + error_response={'Error': {'Code': 'ThrottlingException'}}, + operation_name='DescribeServiceRevisions', + ) + mock_collector.get_current_view = Mock(side_effect=error) + + strategy = InteractiveDisplayStrategy( + display=mock_display, use_color=True + ) + mock_sleep.side_effect = KeyboardInterrupt() + + with pytest.raises(ClientError): + strategy.execute_monitoring( + mock_collector, time.time(), timeout_minutes=1 + ) + + # Verify app.exit() was called in finally block despite exception + mock_display.app.exit.assert_called() diff --git a/tests/unit/customizations/ecs/test_monitorexpressgatewayservice.py b/tests/unit/customizations/ecs/test_monitorexpressgatewayservice.py index e0677a8e3451..6edecc02b4bf 100644 --- a/tests/unit/customizations/ecs/test_monitorexpressgatewayservice.py +++ b/tests/unit/customizations/ecs/test_monitorexpressgatewayservice.py @@ -19,12 +19,6 @@ ECSMonitorExpressGatewayService, ) -# Suppress thread exception warnings - tests use KeyboardInterrupt to exit monitoring loops, -# which causes expected exceptions in background threads -pytestmark = pytest.mark.filterwarnings( - "ignore::pytest.PytestUnhandledThreadExceptionWarning" -) - class TestECSMonitorExpressGatewayServiceCommand: """Test the command class through public interface"""