diff --git a/src/mcp_server_reddit/server.py b/src/mcp_server_reddit/server.py index 4345863..c06ce80 100644 --- a/src/mcp_server_reddit/server.py +++ b/src/mcp_server_reddit/server.py @@ -115,7 +115,25 @@ def get_frontpage_posts(self, limit: int = 10) -> list[Post]: def get_subreddit_info(self, subreddit_name: str) -> SubredditInfo: """Get information about a subreddit""" - subr = self.client.p.subreddit.fetch_by_name(subreddit_name) + try: + subr = self.client.p.subreddit.fetch_by_name(subreddit_name) + except KeyError as exc: + # redditwarp's Subreddit model uses bracket access on fields that + # Reddit may have removed from the API (e.g. active_user_count). + # Fall back to a raw API call to get basic subreddit info. + import logging + logging.getLogger(__name__).warning( + "redditwarp model construction failed for r/%s: %s. " + "Using fallback.", + subreddit_name, exc, + ) + root = self.client.request('GET', f'/r/{subreddit_name}/about') + data = root['data'] + return SubredditInfo( + name=data.get('display_name', subreddit_name), + subscriber_count=data.get('subscribers', -1), + description=data.get('public_description'), + ) return SubredditInfo( name=subr.name, subscriber_count=subr.subscriber_count, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..966d0c4 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,87 @@ +"""Tests for mcp-server-reddit server logic.""" + +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest + +from mcp_server_reddit.server import RedditServer, SubredditInfo + + +class TestGetSubredditInfo: + """Test get_subreddit_info with normal and fallback paths.""" + + def _make_server_with_mock_client(self): + """Create a RedditServer with a mocked redditwarp client.""" + with patch.object(RedditServer, '__init__', lambda self: None): + server = RedditServer() + server.client = MagicMock() + return server + + def test_normal_path(self): + """Returns SubredditInfo when redditwarp model construction succeeds.""" + server = self._make_server_with_mock_client() + + mock_subr = MagicMock() + mock_subr.name = 'python' + mock_subr.subscriber_count = 1200000 + mock_subr.public_description = 'News about Python' + server.client.p.subreddit.fetch_by_name.return_value = mock_subr + + result = server.get_subreddit_info('python') + + assert isinstance(result, SubredditInfo) + assert result.name == 'python' + assert result.subscriber_count == 1200000 + assert result.description == 'News about Python' + + def test_fallback_on_keyerror(self): + """Falls back to raw API when redditwarp raises KeyError. + + This happens when Reddit removes fields from the API response + that redditwarp expects (e.g. active_user_count). + """ + server = self._make_server_with_mock_client() + + # Simulate redditwarp KeyError during model construction + server.client.p.subreddit.fetch_by_name.side_effect = KeyError('active_user_count') + + # Mock the raw API fallback + server.client.request.return_value = { + 'data': { + 'display_name': 'python', + 'subscribers': 1200000, + 'public_description': 'News about Python', + } + } + + result = server.get_subreddit_info('python') + + assert isinstance(result, SubredditInfo) + assert result.name == 'python' + assert result.subscriber_count == 1200000 + assert result.description == 'News about Python' + server.client.request.assert_called_once_with('GET', '/r/python/about') + + def test_fallback_with_missing_fields(self): + """Fallback handles missing fields gracefully with defaults.""" + server = self._make_server_with_mock_client() + server.client.p.subreddit.fetch_by_name.side_effect = KeyError('active_user_count') + + # Minimal API response — missing optional fields + server.client.request.return_value = { + 'data': {} + } + + result = server.get_subreddit_info('test_sub') + + assert result.name == 'test_sub' # falls back to argument + assert result.subscriber_count == -1 # default + assert result.description is None + + def test_non_keyerror_exceptions_propagate(self): + """Non-KeyError exceptions are not caught by the fallback.""" + server = self._make_server_with_mock_client() + server.client.p.subreddit.fetch_by_name.side_effect = ConnectionError('network down') + + with pytest.raises(ConnectionError): + server.get_subreddit_info('python')